Skip to content

Commit 86bb54c

Browse files
authored
Add support for DNS rebinding protections (#861)
1 parent a2f8766 commit 86bb54c

File tree

10 files changed

+799
-13
lines changed

10 files changed

+799
-13
lines changed

src/mcp/server/fastmcp/server.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from mcp.server.stdio import stdio_server
5151
from mcp.server.streamable_http import EventStore
5252
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
53+
from mcp.server.transport_security import TransportSecuritySettings
5354
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
5455
from mcp.types import (
5556
AnyFunction,
@@ -118,6 +119,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
118119

119120
auth: AuthSettings | None = None
120121

122+
# Transport security settings (DNS rebinding protection)
123+
transport_security: TransportSecuritySettings | None = None
124+
121125

122126
def lifespan_wrapper(
123127
app: FastMCP,
@@ -674,6 +678,7 @@ def sse_app(self, mount_path: str | None = None) -> Starlette:
674678

675679
sse = SseServerTransport(
676680
normalized_message_endpoint,
681+
security_settings=self.settings.transport_security,
677682
)
678683

679684
async def handle_sse(scope: Scope, receive: Receive, send: Send):
@@ -779,6 +784,7 @@ def streamable_http_app(self) -> Starlette:
779784
event_store=self._event_store,
780785
json_response=self.settings.json_response,
781786
stateless=self.settings.stateless_http, # Use the stateless setting
787+
security_settings=self.settings.transport_security,
782788
)
783789

784790
# Create the ASGI handler

src/mcp/server/sse.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ async def handle_sse(request):
5252
from starlette.types import Receive, Scope, Send
5353

5454
import mcp.types as types
55+
from mcp.server.transport_security import (
56+
TransportSecurityMiddleware,
57+
TransportSecuritySettings,
58+
)
5559
from mcp.shared.message import ServerMessageMetadata, SessionMessage
5660

5761
logger = logging.getLogger(__name__)
@@ -71,16 +75,22 @@ class SseServerTransport:
7175

7276
_endpoint: str
7377
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
78+
_security: TransportSecurityMiddleware
7479

75-
def __init__(self, endpoint: str) -> None:
80+
def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None:
7681
"""
7782
Creates a new SSE server transport, which will direct the client to POST
7883
messages to the relative or absolute URL given.
84+
85+
Args:
86+
endpoint: The relative or absolute URL for POST messages.
87+
security_settings: Optional security settings for DNS rebinding protection.
7988
"""
8089

8190
super().__init__()
8291
self._endpoint = endpoint
8392
self._read_stream_writers = {}
93+
self._security = TransportSecurityMiddleware(security_settings)
8494
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
8595

8696
@asynccontextmanager
@@ -89,6 +99,13 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
8999
logger.error("connect_sse received non-HTTP request")
90100
raise ValueError("connect_sse can only handle HTTP requests")
91101

102+
# Validate request headers for DNS rebinding protection
103+
request = Request(scope, receive)
104+
error_response = await self._security.validate_request(request, is_post=False)
105+
if error_response:
106+
await error_response(scope, receive, send)
107+
raise ValueError("Request validation failed")
108+
92109
logger.debug("Setting up SSE connection")
93110
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
94111
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
@@ -160,6 +177,11 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
160177
logger.debug("Handling POST message")
161178
request = Request(scope, receive)
162179

180+
# Validate request headers for DNS rebinding protection
181+
error_response = await self._security.validate_request(request, is_post=True)
182+
if error_response:
183+
return await error_response(scope, receive, send)
184+
163185
session_id_param = request.query_params.get("session_id")
164186
if session_id_param is None:
165187
logger.warning("Received request without session_id")

src/mcp/server/streamable_http.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
from starlette.responses import Response
2525
from starlette.types import Receive, Scope, Send
2626

27+
from mcp.server.transport_security import (
28+
TransportSecurityMiddleware,
29+
TransportSecuritySettings,
30+
)
2731
from mcp.shared.message import ServerMessageMetadata, SessionMessage
2832
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
2933
from mcp.types import (
@@ -130,12 +134,14 @@ class StreamableHTTPServerTransport:
130134
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
131135
_write_stream: MemoryObjectSendStream[SessionMessage] | None = None
132136
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
137+
_security: TransportSecurityMiddleware
133138

134139
def __init__(
135140
self,
136141
mcp_session_id: str | None,
137142
is_json_response_enabled: bool = False,
138143
event_store: EventStore | None = None,
144+
security_settings: TransportSecuritySettings | None = None,
139145
) -> None:
140146
"""
141147
Initialize a new StreamableHTTP server transport.
@@ -148,6 +154,7 @@ def __init__(
148154
event_store: Event store for resumability support. If provided,
149155
resumability will be enabled, allowing clients to
150156
reconnect and resume messages.
157+
security_settings: Optional security settings for DNS rebinding protection.
151158
152159
Raises:
153160
ValueError: If the session ID contains invalid characters.
@@ -158,6 +165,7 @@ def __init__(
158165
self.mcp_session_id = mcp_session_id
159166
self.is_json_response_enabled = is_json_response_enabled
160167
self._event_store = event_store
168+
self._security = TransportSecurityMiddleware(security_settings)
161169
self._request_streams: dict[
162170
RequestId,
163171
tuple[
@@ -251,6 +259,14 @@ async def _clean_up_memory_streams(self, request_id: RequestId) -> None:
251259
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
252260
"""Application entry point that handles all HTTP requests"""
253261
request = Request(scope, receive)
262+
263+
# Validate request headers for DNS rebinding protection
264+
is_post = request.method == "POST"
265+
error_response = await self._security.validate_request(request, is_post=is_post)
266+
if error_response:
267+
await error_response(scope, receive, send)
268+
return
269+
254270
if self._terminated:
255271
# If the session has been terminated, return 404 Not Found
256272
response = self._create_error_response(

src/mcp/server/streamable_http_manager.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
EventStore,
2323
StreamableHTTPServerTransport,
2424
)
25+
from mcp.server.transport_security import TransportSecuritySettings
2526

2627
logger = logging.getLogger(__name__)
2728

@@ -60,11 +61,13 @@ def __init__(
6061
event_store: EventStore | None = None,
6162
json_response: bool = False,
6263
stateless: bool = False,
64+
security_settings: TransportSecuritySettings | None = None,
6365
):
6466
self.app = app
6567
self.event_store = event_store
6668
self.json_response = json_response
6769
self.stateless = stateless
70+
self.security_settings = security_settings
6871

6972
# Session tracking (only used if not stateless)
7073
self._session_creation_lock = anyio.Lock()
@@ -162,6 +165,7 @@ async def _handle_stateless_request(
162165
mcp_session_id=None, # No session tracking in stateless mode
163166
is_json_response_enabled=self.json_response,
164167
event_store=None, # No event store in stateless mode
168+
security_settings=self.security_settings,
165169
)
166170

167171
# Start server in a new task
@@ -217,6 +221,7 @@ async def _handle_stateful_request(
217221
mcp_session_id=new_session_id,
218222
is_json_response_enabled=self.json_response,
219223
event_store=self.event_store, # May be None (no resumability)
224+
security_settings=self.security_settings,
220225
)
221226

222227
assert http_transport.mcp_session_id is not None

src/mcp/server/transport_security.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""DNS rebinding protection for MCP server transports."""
2+
3+
import logging
4+
5+
from pydantic import BaseModel, Field
6+
from starlette.requests import Request
7+
from starlette.responses import Response
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class TransportSecuritySettings(BaseModel):
13+
"""Settings for MCP transport security features.
14+
15+
These settings help protect against DNS rebinding attacks by validating
16+
incoming request headers.
17+
"""
18+
19+
enable_dns_rebinding_protection: bool = Field(
20+
default=True,
21+
description="Enable DNS rebinding protection (recommended for production)",
22+
)
23+
24+
allowed_hosts: list[str] = Field(
25+
default=[],
26+
description="List of allowed Host header values. Only applies when "
27+
+ "enable_dns_rebinding_protection is True.",
28+
)
29+
30+
allowed_origins: list[str] = Field(
31+
default=[],
32+
description="List of allowed Origin header values. Only applies when "
33+
+ "enable_dns_rebinding_protection is True.",
34+
)
35+
36+
37+
class TransportSecurityMiddleware:
38+
"""Middleware to enforce DNS rebinding protection for MCP transport endpoints."""
39+
40+
def __init__(self, settings: TransportSecuritySettings | None = None):
41+
# If not specified, disable DNS rebinding protection by default
42+
# for backwards compatibility
43+
self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False)
44+
45+
def _validate_host(self, host: str | None) -> bool:
46+
"""Validate the Host header against allowed values."""
47+
if not host:
48+
logger.warning("Missing Host header in request")
49+
return False
50+
51+
# Check exact match first
52+
if host in self.settings.allowed_hosts:
53+
return True
54+
55+
# Check wildcard port patterns
56+
for allowed in self.settings.allowed_hosts:
57+
if allowed.endswith(":*"):
58+
# Extract base host from pattern
59+
base_host = allowed[:-2]
60+
# Check if the actual host starts with base host and has a port
61+
if host.startswith(base_host + ":"):
62+
return True
63+
64+
logger.warning(f"Invalid Host header: {host}")
65+
return False
66+
67+
def _validate_origin(self, origin: str | None) -> bool:
68+
"""Validate the Origin header against allowed values."""
69+
# Origin can be absent for same-origin requests
70+
if not origin:
71+
return True
72+
73+
# Check exact match first
74+
if origin in self.settings.allowed_origins:
75+
return True
76+
77+
# Check wildcard port patterns
78+
for allowed in self.settings.allowed_origins:
79+
if allowed.endswith(":*"):
80+
# Extract base origin from pattern
81+
base_origin = allowed[:-2]
82+
# Check if the actual origin starts with base origin and has a port
83+
if origin.startswith(base_origin + ":"):
84+
return True
85+
86+
logger.warning(f"Invalid Origin header: {origin}")
87+
return False
88+
89+
def _validate_content_type(self, content_type: str | None) -> bool:
90+
"""Validate the Content-Type header for POST requests."""
91+
if not content_type:
92+
logger.warning("Missing Content-Type header in POST request")
93+
return False
94+
95+
# Content-Type must start with application/json
96+
if not content_type.lower().startswith("application/json"):
97+
logger.warning(f"Invalid Content-Type header: {content_type}")
98+
return False
99+
100+
return True
101+
102+
async def validate_request(self, request: Request, is_post: bool = False) -> Response | None:
103+
"""Validate request headers for DNS rebinding protection.
104+
105+
Returns None if validation passes, or an error Response if validation fails.
106+
"""
107+
# Always validate Content-Type for POST requests
108+
if is_post:
109+
content_type = request.headers.get("content-type")
110+
if not self._validate_content_type(content_type):
111+
return Response("Invalid Content-Type header", status_code=400)
112+
113+
# Skip remaining validation if DNS rebinding protection is disabled
114+
if not self.settings.enable_dns_rebinding_protection:
115+
return None
116+
117+
# Validate Host header
118+
host = request.headers.get("host")
119+
if not self._validate_host(host):
120+
return Response("Invalid Host header", status_code=421)
121+
122+
# Validate Origin header
123+
origin = request.headers.get("origin")
124+
if not self._validate_origin(origin):
125+
return Response("Invalid Origin header", status_code=400)
126+
127+
return None

tests/server/fastmcp/test_integration.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from mcp.client.streamable_http import streamablehttp_client
2424
from mcp.server.fastmcp import Context, FastMCP
2525
from mcp.server.fastmcp.resources import FunctionResource
26+
from mcp.server.transport_security import TransportSecuritySettings
2627
from mcp.shared.context import RequestContext
2728
from mcp.types import (
2829
Completion,
@@ -92,7 +93,10 @@ def stateless_http_server_url(stateless_http_server_port: int) -> str:
9293
# Create a function to make the FastMCP server app
9394
def make_fastmcp_app():
9495
"""Create a FastMCP server without auth settings."""
95-
mcp = FastMCP(name="NoAuthServer")
96+
transport_security = TransportSecuritySettings(
97+
allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"]
98+
)
99+
mcp = FastMCP(name="NoAuthServer", transport_security=transport_security)
96100

97101
# Add a simple tool
98102
@mcp.tool(description="A simple echo tool")
@@ -121,9 +125,10 @@ class AnswerSchema(BaseModel):
121125

122126
def make_everything_fastmcp() -> FastMCP:
123127
"""Create a FastMCP server with all features enabled for testing."""
124-
from mcp.server.fastmcp import Context
125-
126-
mcp = FastMCP(name="EverythingServer")
128+
transport_security = TransportSecuritySettings(
129+
allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"]
130+
)
131+
mcp = FastMCP(name="EverythingServer", transport_security=transport_security)
127132

128133
# Tool with context for logging and progress
129134
@mcp.tool(description="A tool that demonstrates logging and progress", title="Progress Tool")
@@ -333,8 +338,10 @@ def make_everything_fastmcp_app():
333338

334339
def make_fastmcp_streamable_http_app():
335340
"""Create a FastMCP server with StreamableHTTP transport."""
336-
337-
mcp = FastMCP(name="NoAuthServer")
341+
transport_security = TransportSecuritySettings(
342+
allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"]
343+
)
344+
mcp = FastMCP(name="NoAuthServer", transport_security=transport_security)
338345

339346
# Add a simple tool
340347
@mcp.tool(description="A simple echo tool")
@@ -359,8 +366,10 @@ def make_everything_fastmcp_streamable_http_app():
359366

360367
def make_fastmcp_stateless_http_app():
361368
"""Create a FastMCP server with stateless StreamableHTTP transport."""
362-
363-
mcp = FastMCP(name="StatelessServer", stateless_http=True)
369+
transport_security = TransportSecuritySettings(
370+
allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"]
371+
)
372+
mcp = FastMCP(name="StatelessServer", stateless_http=True, transport_security=transport_security)
364373

365374
# Add a simple tool
366375
@mcp.tool(description="A simple echo tool")

0 commit comments

Comments
 (0)