diff --git a/README.md b/README.md index d5005ed6..5939737c 100644 --- a/README.md +++ b/README.md @@ -227,6 +227,48 @@ Postgres MCP Pro supports multiple *access modes* to give you control over the o To use restricted mode, replace `--access-mode=unrestricted` with `--access-mode=restricted` in the configuration examples above. +##### Transport Security Configuration + +Postgres MCP Pro includes DNS rebinding protection to secure the server against certain types of attacks. +By default, the server allows connections from common local and Docker hostnames. +Transport security applies only to network transports (`sse` and `streamable-http`), not `stdio`. + +You can customize this behavior using CLI flags or environment variables (env vars take precedence over CLI flags): + +| CLI Flag | Environment Variable | Description | Default | +|---|---|---|---| +| `--disable-dns-rebinding-protection` | `MCP_ENABLE_DNS_REBINDING_PROTECTION` | Enable/disable DNS rebinding protection | Enabled | +| `--allowed-hosts` | `MCP_ALLOWED_HOSTS` | Comma-separated allowed host patterns | `localhost:*,127.0.0.1:*,0.0.0.0:*,postgres-mcp-server:*,host.docker.internal:*` | +| `--allowed-origins` | `MCP_ALLOWED_ORIGINS` | Comma-separated allowed origins | Empty (allows any origin) | + +For example, to restrict allowed hosts in your configuration: + +```json +{ + "mcpServers": { + "postgres": { + "command": "docker", + "args": [ + "run", + "-i", + "--rm", + "-e", + "DATABASE_URI", + "-e", + "MCP_ALLOWED_HOSTS", + "crystaldba/postgres-mcp", + "--access-mode=unrestricted" + ], + "env": { + "DATABASE_URI": "postgresql://username:password@localhost:5432/dbname", + "MCP_ALLOWED_HOSTS": "localhost:*,myapp.example.com:*" + } + } + } +} +``` + + #### Other MCP Clients Many MCP clients have similar configuration files to Claude Desktop, and you can adapt the examples above to work with the client of your choice. diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index f3ba8f8b..ad407804 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -13,6 +13,7 @@ import mcp.types as types from mcp.server.fastmcp import FastMCP +from mcp.server.transport_security import TransportSecuritySettings from mcp.types import ToolAnnotations from pydantic import Field from pydantic import validate_call @@ -596,6 +597,24 @@ async def main(): default=8000, help="Port for streamable HTTP server (default: 8000)", ) + parser.add_argument( + "--disable-dns-rebinding-protection", + action="store_true", + default=False, + help="Disable DNS rebinding protection (not recommended for production)", + ) + parser.add_argument( + "--allowed-hosts", + type=str, + default=None, + help="Comma-separated allowed Host header values for DNS rebinding protection (e.g. 'localhost:*,127.0.0.1:*')", + ) + parser.add_argument( + "--allowed-origins", + type=str, + default=None, + help="Comma-separated allowed Origin header values for DNS rebinding protection (e.g. 'http://localhost:*')", + ) args = parser.parse_args() @@ -656,6 +675,20 @@ async def main(): logger.warning("Signal handling not supported on Windows") pass + # Apply transport security settings (SSE and streamable-http only) + if args.transport in ("sse", "streamable-http"): + dns_env = os.environ.get("MCP_ENABLE_DNS_REBINDING_PROTECTION") + protection_off = dns_env.lower() in ("false", "0", "no") if dns_env else args.disable_dns_rebinding_protection + hosts = os.environ.get("MCP_ALLOWED_HOSTS", args.allowed_hosts) + origins = os.environ.get("MCP_ALLOWED_ORIGINS", args.allowed_origins) + + if protection_off or hosts or origins: + mcp.settings.transport_security = TransportSecuritySettings( + enable_dns_rebinding_protection=not protection_off, + **{"allowed_hosts": [h.strip() for h in hosts.split(",") if h.strip()]} if hosts else {}, + **{"allowed_origins": [o.strip() for o in origins.split(",") if o.strip()]} if origins else {}, + ) + # Run the server with the selected transport (always async) if args.transport == "stdio": await mcp.run_stdio_async() diff --git a/tests/unit/test_transport_security.py b/tests/unit/test_transport_security.py new file mode 100644 index 00000000..a54ae666 --- /dev/null +++ b/tests/unit/test_transport_security.py @@ -0,0 +1,224 @@ +import sys +from unittest.mock import AsyncMock +from unittest.mock import patch + +import pytest + +_TRANSPORT_MOCK_MAP = { + "sse": "postgres_mcp.server.mcp.run_sse_async", + "streamable-http": "postgres_mcp.server.mcp.run_streamable_http_async", +} + +_MCP_ENV_KEYS = [ + "MCP_ENABLE_DNS_REBINDING_PROTECTION", + "MCP_ALLOWED_HOSTS", + "MCP_ALLOWED_ORIGINS", +] + + +@pytest.mark.parametrize("transport", ["sse", "streamable-http"]) +class TestTransportSecurityIntegration: + @pytest.fixture(autouse=True) + def _preserve_mcp_state(self, monkeypatch: pytest.MonkeyPatch): + from postgres_mcp.server import mcp + + original_argv = sys.argv + original_security = mcp.settings.transport_security + for key in _MCP_ENV_KEYS: + monkeypatch.delenv(key, raising=False) + yield + sys.argv = original_argv + mcp.settings.transport_security = original_security + + @pytest.mark.asyncio + async def test_disable_dns_rebinding_via_cli_flag(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--disable-dns-rebinding-protection", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + ): + await main() + assert mcp.settings.transport_security is not None + assert mcp.settings.transport_security.enable_dns_rebinding_protection is False + + @pytest.mark.asyncio + async def test_disable_dns_rebinding_via_env(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + patch.dict("os.environ", {"MCP_ENABLE_DNS_REBINDING_PROTECTION": "false"}), + ): + await main() + assert mcp.settings.transport_security is not None + assert mcp.settings.transport_security.enable_dns_rebinding_protection is False + + @pytest.mark.asyncio + async def test_allowed_hosts_via_cli(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--allowed-hosts", + "localhost:*,127.0.0.1:*", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + ): + await main() + assert mcp.settings.transport_security is not None + assert "localhost:*" in mcp.settings.transport_security.allowed_hosts + assert "127.0.0.1:*" in mcp.settings.transport_security.allowed_hosts + + @pytest.mark.asyncio + async def test_allowed_hosts_env_overrides_cli(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--allowed-hosts", + "cli-host:*", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + patch.dict("os.environ", {"MCP_ALLOWED_HOSTS": "env-host:*"}), + ): + await main() + assert mcp.settings.transport_security is not None + assert "env-host:*" in mcp.settings.transport_security.allowed_hosts + assert "cli-host:*" not in mcp.settings.transport_security.allowed_hosts + + @pytest.mark.asyncio + async def test_allowed_origins_via_cli(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--allowed-origins", + "http://localhost:*", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + ): + await main() + assert mcp.settings.transport_security is not None + assert "http://localhost:*" in mcp.settings.transport_security.allowed_origins + + @pytest.mark.asyncio + async def test_allowed_origins_env_overrides_cli(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--allowed-origins", + "http://cli-origin:*", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + patch.dict("os.environ", {"MCP_ALLOWED_ORIGINS": "http://env-origin:*"}), + ): + await main() + assert mcp.settings.transport_security is not None + assert "http://env-origin:*" in mcp.settings.transport_security.allowed_origins + assert "http://cli-origin:*" not in mcp.settings.transport_security.allowed_origins + + @pytest.mark.asyncio + async def test_env_protection_true_overrides_cli_disable(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--disable-dns-rebinding-protection", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + patch.dict("os.environ", {"MCP_ENABLE_DNS_REBINDING_PROTECTION": "true"}), + ): + await main() + assert mcp.settings.transport_security is not None + assert mcp.settings.transport_security.enable_dns_rebinding_protection is True + + @pytest.mark.asyncio + async def test_default_defers_to_fastmcp(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + ): + await main() + assert mcp.settings.transport_security is not None + assert mcp.settings.transport_security.enable_dns_rebinding_protection is True + + @pytest.mark.asyncio + async def test_database_url_after_flags_not_consumed(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + f"--transport={transport}", + "--allowed-hosts", + "localhost:*,my-gateway:8080", + "--allowed-origins", + "http://localhost:*,http://my-gateway:*", + "postgresql://user:password@localhost/db", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + ): + await main() + assert mcp.settings.transport_security is not None + assert "localhost:*" in mcp.settings.transport_security.allowed_hosts + assert "my-gateway:8080" in mcp.settings.transport_security.allowed_hosts