Skip to content

Commit 5dd1d59

Browse files
committed
fix: Fix logic that determines standard resource vs. resource template to account for context param
1 parent 2fe56e5 commit 5dd1d59

File tree

11 files changed

+280
-58
lines changed

11 files changed

+280
-58
lines changed

src/mcp/server/mcpserver/resources/base.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Base classes and interfaces for MCPServer resources."""
22

3+
from __future__ import annotations
4+
35
import abc
4-
from typing import Any
6+
from typing import TYPE_CHECKING, Any
57

68
from pydantic import (
79
BaseModel,
@@ -13,6 +15,10 @@
1315

1416
from mcp.types import Annotations, Icon
1517

18+
if TYPE_CHECKING:
19+
from mcp.server.context import LifespanContextT, RequestT
20+
from mcp.server.mcpserver.server import Context
21+
1622

1723
class Resource(BaseModel, abc.ABC):
1824
"""Base class for all resources."""
@@ -43,6 +49,9 @@ def set_default_name(cls, name: str | None, info: ValidationInfo) -> str:
4349
raise ValueError("Either name or uri must be provided")
4450

4551
@abc.abstractmethod
46-
async def read(self) -> str | bytes:
52+
async def read(
53+
self,
54+
context: Context[LifespanContextT, RequestT] | None = None,
55+
) -> str | bytes:
4756
"""Read the resource content."""
4857
pass # pragma: no cover

src/mcp/server/mcpserver/resources/templates.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,16 @@
22

33
from __future__ import annotations
44

5-
import inspect
65
import re
76
from collections.abc import Callable
87
from typing import TYPE_CHECKING, Any
98
from urllib.parse import unquote
109

11-
from pydantic import BaseModel, Field, validate_call
10+
from pydantic import BaseModel, Field
1211

1312
from mcp.server.mcpserver.resources.types import FunctionResource, Resource
14-
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter, inject_context
15-
from mcp.server.mcpserver.utilities.func_metadata import func_metadata
13+
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter
14+
from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata, is_async_callable
1615
from mcp.types import Annotations, Icon
1716

1817
if TYPE_CHECKING:
@@ -34,6 +33,10 @@ class ResourceTemplate(BaseModel):
3433
fn: Callable[..., Any] = Field(exclude=True)
3534
parameters: dict[str, Any] = Field(description="JSON schema for function parameters")
3635
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
36+
fn_metadata: FuncMetadata = Field(
37+
description="Metadata about the function including a pydantic model for arguments"
38+
)
39+
is_async: bool = Field(description="Whether the function is async")
3740

3841
@classmethod
3942
def from_function(
@@ -58,16 +61,15 @@ def from_function(
5861
if context_kwarg is None: # pragma: no branch
5962
context_kwarg = find_context_parameter(fn)
6063

64+
is_async = is_async_callable(fn)
65+
6166
# Get schema from func_metadata, excluding context parameter
6267
func_arg_metadata = func_metadata(
6368
fn,
6469
skip_names=[context_kwarg] if context_kwarg is not None else [],
6570
)
6671
parameters = func_arg_metadata.arg_model.model_json_schema()
6772

68-
# ensure the arguments are properly cast
69-
fn = validate_call(fn)
70-
7173
return cls(
7274
uri_template=uri_template,
7375
name=func_name,
@@ -80,6 +82,8 @@ def from_function(
8082
fn=fn,
8183
parameters=parameters,
8284
context_kwarg=context_kwarg,
85+
fn_metadata=func_arg_metadata,
86+
is_async=is_async,
8387
)
8488

8589
def matches(self, uri: str) -> dict[str, Any] | None:
@@ -103,13 +107,12 @@ async def create_resource(
103107
) -> Resource:
104108
"""Create a resource from the template with the given parameters."""
105109
try:
106-
# Add context to params if needed
107-
params = inject_context(self.fn, params, context, self.context_kwarg)
108-
109-
# Call function and check if result is a coroutine
110-
result = self.fn(**params)
111-
if inspect.iscoroutine(result):
112-
result = await result
110+
result = await self.fn_metadata.call_fn_with_arg_validation(
111+
self.fn,
112+
self.is_async,
113+
params,
114+
{self.context_kwarg: context} if self.context_kwarg is not None else None,
115+
)
113116

114117
return FunctionResource(
115118
uri=uri, # type: ignore

src/mcp/server/mcpserver/resources/types.py

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,38 @@
11
"""Concrete resource implementations."""
22

3+
from __future__ import annotations
4+
35
import inspect
46
import json
57
from collections.abc import Callable
68
from pathlib import Path
7-
from typing import Any
9+
from typing import TYPE_CHECKING, Any
810

911
import anyio
1012
import anyio.to_thread
1113
import httpx
1214
import pydantic
1315
import pydantic_core
14-
from pydantic import Field, ValidationInfo, validate_call
16+
from pydantic import Field, ValidationInfo
1517

1618
from mcp.server.mcpserver.resources.base import Resource
19+
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter
1720
from mcp.types import Annotations, Icon
1821

22+
if TYPE_CHECKING:
23+
from mcp.server.context import LifespanContextT, RequestT
24+
from mcp.server.mcpserver.server import Context
25+
1926

2027
class TextResource(Resource):
2128
"""A resource that reads from a string."""
2229

2330
text: str = Field(description="Text content of the resource")
2431

25-
async def read(self) -> str:
32+
async def read(
33+
self,
34+
context: Context[LifespanContextT, RequestT] | None = None,
35+
) -> str:
2636
"""Read the text content."""
2737
return self.text # pragma: no cover
2838

@@ -32,7 +42,10 @@ class BinaryResource(Resource):
3242

3343
data: bytes = Field(description="Binary content of the resource")
3444

35-
async def read(self) -> bytes:
45+
async def read(
46+
self,
47+
context: Context[LifespanContextT, RequestT] | None = None,
48+
) -> bytes:
3649
"""Read the binary content."""
3750
return self.data # pragma: no cover
3851

@@ -50,13 +63,22 @@ class FunctionResource(Resource):
5063
- other types will be converted to JSON
5164
"""
5265

53-
fn: Callable[[], Any] = Field(exclude=True)
66+
fn: Callable[..., Any] = Field(exclude=True)
67+
context_kwarg: str | None = Field(default=None, description="Name of the kwarg that should receive context")
5468

55-
async def read(self) -> str | bytes:
69+
async def read(
70+
self,
71+
context: Context[LifespanContextT, RequestT] | None = None,
72+
) -> str | bytes:
5673
"""Read the resource by calling the wrapped function."""
5774
try:
58-
# Call the function first to see if it returns a coroutine
59-
result = self.fn()
75+
# Inject context if needed
76+
kwargs: dict[str, Any] = {}
77+
if self.context_kwarg is not None and context is not None:
78+
kwargs[self.context_kwarg] = context
79+
80+
# Call the function
81+
result = self.fn(**kwargs)
6082
# If it's a coroutine, await it
6183
if inspect.iscoroutine(result):
6284
result = await result
@@ -84,14 +106,14 @@ def from_function(
84106
icons: list[Icon] | None = None,
85107
annotations: Annotations | None = None,
86108
meta: dict[str, Any] | None = None,
87-
) -> "FunctionResource":
109+
) -> FunctionResource:
88110
"""Create a FunctionResource from a function."""
89111
func_name = name or fn.__name__
90112
if func_name == "<lambda>": # pragma: no cover
91113
raise ValueError("You must provide a name for lambda functions")
92114

93-
# ensure the arguments are properly cast
94-
fn = validate_call(fn)
115+
# Find context parameter if it exists
116+
context_kwarg = find_context_parameter(fn)
95117

96118
return cls(
97119
uri=uri,
@@ -100,6 +122,7 @@ def from_function(
100122
description=description or fn.__doc__ or "",
101123
mime_type=mime_type or "text/plain",
102124
fn=fn,
125+
context_kwarg=context_kwarg,
103126
icons=icons,
104127
annotations=annotations,
105128
meta=meta,
@@ -139,7 +162,10 @@ def set_binary_from_mime_type(cls, is_binary: bool, info: ValidationInfo) -> boo
139162
mime_type = info.data.get("mime_type", "text/plain")
140163
return not mime_type.startswith("text/")
141164

142-
async def read(self) -> str | bytes:
165+
async def read(
166+
self,
167+
context: Context[LifespanContextT, RequestT] | None = None,
168+
) -> str | bytes:
143169
"""Read the file content."""
144170
try:
145171
if self.is_binary:
@@ -155,7 +181,10 @@ class HttpResource(Resource):
155181
url: str = Field(description="URL to fetch content from")
156182
mime_type: str = Field(default="application/json", description="MIME type of the resource content")
157183

158-
async def read(self) -> str | bytes:
184+
async def read(
185+
self,
186+
context: Context[LifespanContextT, RequestT] | None = None,
187+
) -> str | bytes:
159188
"""Read the HTTP content."""
160189
async with httpx.AsyncClient() as client: # pragma: no cover
161190
response = await client.get(self.url)
@@ -193,7 +222,10 @@ def list_files(self) -> list[Path]: # pragma: no cover
193222
except Exception as e:
194223
raise ValueError(f"Error listing directory {self.path}: {e}")
195224

196-
async def read(self) -> str: # Always returns JSON string # pragma: no cover
225+
async def read(
226+
self,
227+
context: Context[LifespanContextT, RequestT] | None = None,
228+
) -> str: # Always returns JSON string # pragma: no cover
197229
"""Read the directory listing."""
198230
try:
199231
files = await anyio.to_thread.run_sync(self.list_files)

src/mcp/server/mcpserver/server.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContent
446446
raise ResourceError(f"Unknown resource: {uri}")
447447

448448
try:
449-
content = await resource.read()
449+
content = await resource.read(context=context)
450450
return [ReadResourceContents(content=content, mime_type=resource.mime_type, meta=resource.meta)]
451451
except Exception as exc:
452452
logger.exception(f"Error getting resource {uri}")
@@ -667,19 +667,16 @@ async def get_weather(city: str) -> str:
667667
def decorator(fn: _CallableT) -> _CallableT:
668668
# Check if this should be a template
669669
sig = inspect.signature(fn)
670-
has_uri_params = "{" in uri and "}" in uri
671-
has_func_params = bool(sig.parameters)
670+
uri_params = set(re.findall(r"{(\w+)}", uri))
671+
context_param = find_context_parameter(fn)
672+
func_params = {p for p in sig.parameters.keys() if p != context_param}
672673

673-
if has_uri_params or has_func_params:
674-
# Check for Context parameter to exclude from validation
675-
context_param = find_context_parameter(fn)
676-
677-
# Validate that URI params match function params (excluding context)
678-
uri_params = set(re.findall(r"{(\w+)}", uri))
679-
# We need to remove the context_param from the resource function if
680-
# there is any.
681-
func_params = {p for p in sig.parameters.keys() if p != context_param}
674+
# Determine if this should be a template
675+
has_uri_params = len(uri_params) != 0
676+
has_func_params = len(func_params) != 0
682677

678+
if has_uri_params or has_func_params:
679+
# Validate that URI params match function params
683680
if uri_params != func_params:
684681
raise ValueError(
685682
f"Mismatch between URI parameters {uri_params} and function parameters {func_params}"

src/mcp/server/mcpserver/tools/base.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from __future__ import annotations
22

3-
import functools
4-
import inspect
53
from collections.abc import Callable
64
from functools import cached_property
75
from typing import TYPE_CHECKING, Any
@@ -10,7 +8,7 @@
108

119
from mcp.server.mcpserver.exceptions import ToolError
1210
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter
13-
from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata
11+
from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata, is_async_callable
1412
from mcp.shared.exceptions import UrlElicitationRequiredError
1513
from mcp.shared.tool_name_validation import validate_and_warn_tool_name
1614
from mcp.types import Icon, ToolAnnotations
@@ -63,7 +61,7 @@ def from_function(
6361
raise ValueError("You must provide a name for lambda functions")
6462

6563
func_doc = description or fn.__doc__ or ""
66-
is_async = _is_async_callable(fn)
64+
is_async = is_async_callable(fn)
6765

6866
if context_kwarg is None: # pragma: no branch
6967
context_kwarg = find_context_parameter(fn)
@@ -114,12 +112,3 @@ async def run(
114112
raise
115113
except Exception as e:
116114
raise ToolError(f"Error executing tool {self.name}: {e}") from e
117-
118-
119-
def _is_async_callable(obj: Any) -> bool:
120-
while isinstance(obj, functools.partial): # pragma: lax no cover
121-
obj = obj.func
122-
123-
return inspect.iscoroutinefunction(obj) or (
124-
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
125-
)

src/mcp/server/mcpserver/utilities/func_metadata.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,3 +523,13 @@ def _convert_to_content(result: Any) -> Sequence[ContentBlock]:
523523
result = pydantic_core.to_json(result, fallback=str, indent=2).decode()
524524

525525
return [TextContent(type="text", text=result)]
526+
527+
528+
def is_async_callable(obj: Any) -> bool:
529+
"""Check if an object is an async callable."""
530+
while isinstance(obj, functools.partial):
531+
obj = obj.func
532+
533+
return inspect.iscoroutinefunction(obj) or (
534+
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
535+
)

0 commit comments

Comments
 (0)