Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,16 @@ For Windsurf, the format in `mcp_config.json` is slightly different:
}
```

For local integration with your browser using, for example, [MCP for claude.ai](https://chromewebstore.google.com/detail/jbdhaamjibfahpekpnjeikanebpdpfpb?utm_source=item-share-cb), you may need to allow certain CORS origins, such as https://claude.ai. To do this, start the server with the `--cors-origins` parameter and provide the list of origins you want to whitelist.

For example, with Docker run:

```bash
docker run -p 8000:8000 \
-e DATABASE_URI=postgresql://username:password@localhost:5432/dbname \
crystaldba/postgres-mcp --access-mode=unrestricted --transport=sse --cors-origins https://claude.ai
```

## Postgres Extension Installation (Optional)

To enable index tuning and comprehensive performance analysis you need to load the `pg_statements` and `hypopg` extensions on your database.
Expand Down
31 changes: 27 additions & 4 deletions src/postgres_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
from typing import Union

import mcp.types as types
import uvicorn
from mcp.server.fastmcp import FastMCP
from pydantic import Field
from pydantic import validate_call
from starlette.middleware.cors import CORSMiddleware

from postgres_mcp.index.dta_calc import DatabaseTuningAdvisor

Expand Down Expand Up @@ -539,6 +541,12 @@ async def main():
default=8000,
help="Port for SSE server (default: 8000)",
)
parser.add_argument(
"--cors-origins",
nargs="*",
default=[],
help="List of allowed CORS origins (default: empty, no CORS)",
)

args = parser.parse_args()

Expand Down Expand Up @@ -589,10 +597,25 @@ async def main():
if args.transport == "stdio":
await mcp.run_stdio_async()
else:
# Update FastMCP settings based on command line arguments
mcp.settings.host = args.sse_host
mcp.settings.port = args.sse_port
await mcp.run_sse_async()
starlette_app = mcp.sse_app()

if args.cors_origins:
logger.info(f"Enabling CORS for origins: {", ".join(args.cors_origins)}")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Fix f-string syntax that breaks server startup

This logger.info line has an invalid f-string because the inner ", " string terminates the outer double-quoted f-string, which makes the module fail to parse at import time. In any environment that executes this code path (including postgres-mcp --transport=sse), Python will raise a SyntaxError before the server can start. Use single quotes inside the join or escape the quotes to keep the f-string valid.

Useful? React with 👍 / 👎.

starlette_app.add_middleware(
CORSMiddleware,
allow_origins=args.cors_origins,
allow_methods=['GET', 'POST', 'OPTIONS'],
allow_headers=['*']
)

config = uvicorn.Config(
starlette_app,
host=args.sse_host,
port=args.sse_port,
log_level="info",
)
server = uvicorn.Server(config)
await server.serve()


async def shutdown(sig=None):
Expand Down
136 changes: 136 additions & 0 deletions tests/unit/test_cors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""Tests for CORS support in SSE transport."""

import pytest
from starlette.middleware.cors import CORSMiddleware
from starlette.testclient import TestClient

from postgres_mcp.server import mcp


@pytest.fixture
def app_with_cors():
"""Create an SSE app with CORS middleware configured."""
app = mcp.sse_app()
app.add_middleware(
CORSMiddleware,
allow_origins=["https://claude.ai", "https://example.com"],
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["*"],
)
return app


@pytest.fixture
def app_without_cors():
"""Create an SSE app without CORS middleware."""
return mcp.sse_app()


class TestCorsPreflightRequests:
"""Test CORS preflight (OPTIONS) requests."""

def test_preflight_allowed_origin_returns_cors_headers(self, app_with_cors):
"""OPTIONS preflight from allowed origin should return CORS headers."""
client = TestClient(app_with_cors, raise_server_exceptions=False)
response = client.options(
"/sse",
headers={
"Origin": "https://claude.ai",
"Access-Control-Request-Method": "GET",
},
)
assert response.status_code == 200
assert response.headers.get("access-control-allow-origin") == "https://claude.ai"
assert "GET" in response.headers.get("access-control-allow-methods", "")

def test_preflight_second_allowed_origin(self, app_with_cors):
"""OPTIONS preflight from second allowed origin should also work."""
client = TestClient(app_with_cors, raise_server_exceptions=False)
response = client.options(
"/sse",
headers={
"Origin": "https://example.com",
"Access-Control-Request-Method": "GET",
},
)
assert response.status_code == 200
assert response.headers.get("access-control-allow-origin") == "https://example.com"

def test_preflight_disallowed_origin_no_cors_header(self, app_with_cors):
"""OPTIONS preflight from non-allowed origin should not return CORS header."""
client = TestClient(app_with_cors, raise_server_exceptions=False)
response = client.options(
"/sse",
headers={
"Origin": "https://malicious.com",
"Access-Control-Request-Method": "GET",
},
)
# The response may be 200 or 400, but should NOT have the allow-origin header
assert response.headers.get("access-control-allow-origin") is None

def test_preflight_messages_endpoint(self, app_with_cors):
"""OPTIONS preflight on /messages/ endpoint should also work."""
client = TestClient(app_with_cors, raise_server_exceptions=False)
response = client.options(
"/messages/",
headers={
"Origin": "https://claude.ai",
"Access-Control-Request-Method": "POST",
},
)
assert response.status_code == 200
assert response.headers.get("access-control-allow-origin") == "https://claude.ai"
assert "POST" in response.headers.get("access-control-allow-methods", "")


class TestCorsOnActualRequests:
"""Test CORS headers on actual (non-preflight) requests."""

def test_post_request_with_allowed_origin(self, app_with_cors):
"""POST request from allowed origin should include CORS header in response."""
client = TestClient(app_with_cors, raise_server_exceptions=False)
# Send a POST to /messages/ - it will fail (no valid session) but CORS headers should be present
response = client.post(
"/messages/",
headers={"Origin": "https://claude.ai"},
content="test",
)
# Even if the request fails, CORS headers should be present
assert response.headers.get("access-control-allow-origin") == "https://claude.ai"

def test_post_request_with_disallowed_origin(self, app_with_cors):
"""POST request from non-allowed origin should not have CORS header."""
client = TestClient(app_with_cors, raise_server_exceptions=False)
response = client.post(
"/messages/",
headers={"Origin": "https://malicious.com"},
content="test",
)
assert response.headers.get("access-control-allow-origin") is None


class TestCorsDisabled:
"""Test behavior when CORS middleware is not configured."""

def test_preflight_without_cors_middleware(self, app_without_cors):
"""App without CORS middleware should not handle preflight specially."""
client = TestClient(app_without_cors, raise_server_exceptions=False)
response = client.options(
"/sse",
headers={
"Origin": "https://claude.ai",
"Access-Control-Request-Method": "GET",
},
)
assert response.headers.get("access-control-allow-origin") is None

def test_request_without_cors_middleware(self, app_without_cors):
"""App without CORS middleware should not return CORS headers."""
client = TestClient(app_without_cors, raise_server_exceptions=False)
response = client.post(
"/messages/",
headers={"Origin": "https://claude.ai"},
content="test",
)
assert response.headers.get("access-control-allow-origin") is None