Skip to content

Commit 13f3e8d

Browse files
Add oauth support for JSON-string passed configs
1 parent 7046e62 commit 13f3e8d

4 files changed

Lines changed: 157 additions & 15 deletions

File tree

mcp_compressor/main.py

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,12 @@
2525
import keyring.errors
2626
import psutil
2727
import typer
28+
from click.core import ParameterSource
2829
from cryptography.fernet import Fernet
2930
from fastmcp import FastMCP
3031
from fastmcp.client.auth import OAuth
31-
from fastmcp.client.transports import MCPConfigTransport, SSETransport, StdioTransport, StreamableHttpTransport
32-
from fastmcp.mcp_config import MCPConfig
32+
from fastmcp.client.transports import SSETransport, StdioTransport, StreamableHttpTransport
33+
from fastmcp.mcp_config import MCPConfig, RemoteMCPServer, StdioMCPServer
3334
from fastmcp.server import create_proxy
3435
from fastmcp.server.providers.proxy import ProxyClient
3536
from key_value.aio.protocols import AsyncKeyValue
@@ -70,6 +71,7 @@ def _version_callback(value: bool) -> None:
7071

7172
@app.command()
7273
def main(
74+
ctx: typer.Context,
7375
command_or_url_list: Annotated[
7476
list[str],
7577
typer.Argument(
@@ -223,6 +225,23 @@ def main(
223225
raise typer.BadParameter(str(exc), param_hint="'COMMAND_OR_URL'") from exc
224226

225227
if parsed_config is not None:
228+
conflicting_config_options = [
229+
option_name
230+
for option_name in ("cwd", "env_list", "header_list", "timeout")
231+
if ctx.get_parameter_source(option_name) == ParameterSource.COMMANDLINE
232+
]
233+
if conflicting_config_options:
234+
joined_options = ", ".join(
235+
f"--{name.removesuffix('_list').replace('_', '-')}" for name in conflicting_config_options
236+
)
237+
raise typer.BadParameter(
238+
(
239+
f"JSON MCP config input cannot be combined with {joined_options}; configure those values inside "
240+
"the JSON instead."
241+
),
242+
param_hint="'COMMAND_OR_URL'",
243+
)
244+
226245
_config, config_server_name = parsed_config
227246
resolved_server_name = server_name or config_server_name
228247

@@ -423,17 +442,13 @@ async def _server(
423442
parsed_config = _parse_single_server_mcp_config(command_or_url_list)
424443
if parsed_config is not None:
425444
config, _config_server_name = parsed_config
426-
transport_type: Literal["stdio", "http", "sse"] = _infer_transport_type(
427-
str(next(iter(config.mcpServers.values())).model_dump().get("url") or command_or_url)
428-
)
429-
transport = MCPConfigTransport(config)
445+
transport, transport_type = _get_single_server_transport_from_mcp_config(config=config)
430446
logger.info("Loaded single-server MCP config JSON")
431447
else:
432448
transport_type = _infer_transport_type(command_or_url)
433449
logger.info(f"Inferred transport type: {transport_type}")
434450

435451
# Handle different transport types
436-
transport: TransportType
437452
if transport_type == "stdio":
438453
transport = _get_stdio_transport(
439454
command=command_or_url_list[0], args=command_or_url_list[1:], cwd=cwd, env_list=env_list
@@ -657,6 +672,57 @@ def _parse_single_server_mcp_config(command_or_url_list: list[str]) -> tuple[MCP
657672
return config, server_name
658673

659674

675+
def _get_single_server_transport_from_mcp_config(
676+
config: MCPConfig,
677+
) -> tuple[TransportType, Literal["stdio", "http", "sse"]]:
678+
"""Create a transport for a validated single-server MCP config.
679+
680+
Single-server config JSON is self-contained: transport settings come only from the config.
681+
Remote configs default to the same OAuth flow used by direct URL inputs unless explicit auth is provided.
682+
"""
683+
server_config = next(iter(config.mcpServers.values()))
684+
685+
if isinstance(server_config, StdioMCPServer):
686+
return (
687+
_get_stdio_transport(
688+
command=server_config.command,
689+
args=server_config.args,
690+
cwd=server_config.cwd,
691+
env_list=[f"{key}={value}" for key, value in server_config.env.items()] or None,
692+
),
693+
"stdio",
694+
)
695+
696+
if isinstance(server_config, RemoteMCPServer):
697+
transport_type = "sse" if (server_config.transport == "sse") else _infer_transport_type(server_config.url)
698+
auth = server_config.auth
699+
if auth in (None, "oauth"):
700+
auth = OAuth(mcp_url=server_config.url, token_storage=_build_token_storage())
701+
702+
headers = {key: _interpolate_string(value) for key, value in server_config.headers.items()}
703+
if transport_type == "http":
704+
return (
705+
StreamableHttpTransport(
706+
url=server_config.url,
707+
headers=headers,
708+
auth=auth,
709+
sse_read_timeout=server_config.sse_read_timeout,
710+
),
711+
"http",
712+
)
713+
return (
714+
SSETransport(
715+
url=server_config.url,
716+
headers=headers,
717+
auth=auth,
718+
sse_read_timeout=server_config.sse_read_timeout,
719+
),
720+
"sse",
721+
)
722+
723+
raise ValueError("Unsupported single-server MCP config type.")
724+
725+
660726
def _interpolate_string(value: str) -> str:
661727
"""Interpolate environment variables in a single string.
662728

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "mcp-compressor"
3-
version = "0.2.9"
3+
version = "0.2.10"
44
description = "An MCP server wrapper for reducing tokens consumed by MCP tools."
55
authors = [{ name = "Tim Esler", email = "tesler@atlassian.com" }]
66
readme = "README.md"
@@ -32,6 +32,7 @@ dependencies = [
3232
"toons>=0.5.3",
3333
"typer>=0.16.0",
3434
"uvicorn>=0.30.0",
35+
"click>=8.3.1",
3536
]
3637

3738
[project.urls]

tests/test_main.py

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
from typing import Any, cast
88

99
import pytest
10+
from fastmcp.client.auth.bearer import BearerAuth
1011
from fastmcp.client.auth.oauth import ClientNotFoundError, OAuth
11-
from fastmcp.client.transports import MCPConfigTransport, SSETransport, StdioTransport, StreamableHttpTransport
12+
from fastmcp.client.transports import SSETransport, StdioTransport, StreamableHttpTransport
1213
from fastmcp.exceptions import McpError
1314
from typer.testing import CliRunner
1415

@@ -19,6 +20,7 @@
1920
suppress_recoverable_oauth_traceback_logging,
2021
)
2122
from mcp_compressor.main import (
23+
_get_single_server_transport_from_mcp_config,
2224
_get_sse_transport,
2325
_get_stdio_transport,
2426
_get_streamable_http_transport,
@@ -120,6 +122,51 @@ def test_parse_single_server_mcp_config_rejects_multiple_servers() -> None:
120122
# Tests for transport creation functions
121123

122124

125+
def test_get_single_server_transport_from_mcp_config_remote_defaults_to_oauth(monkeypatch: pytest.MonkeyPatch) -> None:
126+
config_json = '{"mcpServers": {"weather": {"url": "https://example.com/mcp"}}}'
127+
parsed = _parse_single_server_mcp_config([config_json])
128+
assert parsed is not None
129+
config, _ = parsed
130+
131+
token_storage = object()
132+
monkeypatch.setattr(main_module, "_build_token_storage", lambda: token_storage)
133+
134+
transport, transport_type = _get_single_server_transport_from_mcp_config(config=config)
135+
136+
assert transport_type == "http"
137+
assert isinstance(transport, StreamableHttpTransport)
138+
assert isinstance(transport.auth, OAuth)
139+
assert transport.auth.mcp_url == "https://example.com/mcp"
140+
assert transport.auth._token_storage is token_storage
141+
142+
143+
def test_get_single_server_transport_from_mcp_config_remote_preserves_explicit_auth() -> None:
144+
config_json = '{"mcpServers": {"weather": {"url": "https://example.com/mcp", "auth": "abc"}}}'
145+
parsed = _parse_single_server_mcp_config([config_json])
146+
assert parsed is not None
147+
config, _ = parsed
148+
149+
transport, transport_type = _get_single_server_transport_from_mcp_config(config=config)
150+
151+
assert transport_type == "http"
152+
assert isinstance(transport, StreamableHttpTransport)
153+
assert isinstance(transport.auth, BearerAuth)
154+
assert transport.auth.token.get_secret_value() == "abc"
155+
156+
157+
def test_get_single_server_transport_from_mcp_config_sse_uses_only_config_timeout() -> None:
158+
config_json = '{"mcpServers": {"weather": {"url": "https://example.com/sse", "transport": "sse"}}}'
159+
parsed = _parse_single_server_mcp_config([config_json])
160+
assert parsed is not None
161+
config, _ = parsed
162+
163+
transport, transport_type = _get_single_server_transport_from_mcp_config(config=config)
164+
165+
assert transport_type == "sse"
166+
assert isinstance(transport, SSETransport)
167+
assert transport.sse_read_timeout is None
168+
169+
123170
def test_get_stdio_transport(tmp_path) -> None:
124171
"""Test that stdio transport is created with correct parameters."""
125172
transport = _get_stdio_transport(
@@ -296,7 +343,34 @@ async def fake_async_main(**kwargs: Any) -> None:
296343
assert async_main_called is False
297344

298345

299-
async def test_server_uses_mcp_config_transport_for_single_server_json(monkeypatch: pytest.MonkeyPatch) -> None:
346+
@pytest.mark.parametrize(
347+
("extra_args", "expected_option"),
348+
[
349+
(["--cwd", "."], "--cwd"),
350+
(["--env", "FOO=bar"], "--env"),
351+
(["--header", "Authorization=Bearer abc"], "--header"),
352+
(["--timeout", "30"], "--timeout"),
353+
],
354+
)
355+
def test_single_server_mcp_config_rejects_conflicting_transport_options(
356+
runner: CliRunner, monkeypatch: pytest.MonkeyPatch, extra_args: list[str], expected_option: str
357+
) -> None:
358+
config_json = '{"mcpServers": {"weather": {"url": "https://example.com/mcp"}}}'
359+
async_main_called = False
360+
361+
async def fake_async_main(**kwargs: Any) -> None:
362+
nonlocal async_main_called
363+
async_main_called = True
364+
365+
monkeypatch.setattr(main_module, "_async_main", fake_async_main)
366+
result = runner.invoke(app, [*extra_args, config_json])
367+
368+
assert result.exit_code != 0
369+
assert expected_option in _strip_ansi(result.output)
370+
assert async_main_called is False
371+
372+
373+
async def test_server_uses_single_server_config_transport_directly(monkeypatch: pytest.MonkeyPatch) -> None:
300374
config_json = '{"mcpServers": {"weather": {"command": "uvx", "args": ["mcp-weather"]}}}'
301375
captured: dict[str, Any] = {}
302376

@@ -334,7 +408,9 @@ async def get_compression_stats(self) -> dict[str, int]:
334408
) as mcp:
335409
assert mcp is fake_mcp
336410

337-
assert isinstance(captured["transport"], MCPConfigTransport)
411+
assert isinstance(captured["transport"], StdioTransport)
412+
assert captured["transport"].command == "uvx"
413+
assert captured["transport"].args == ["mcp-weather"]
338414
assert captured["compressed_tools_kwargs"]["server_name"] == "weather"
339415

340416

@@ -576,9 +652,6 @@ def test_version_flag(runner: CliRunner) -> None:
576652

577653
def test_version_short_flag(runner: CliRunner) -> None:
578654
"""-V should be an alias for --version."""
579-
580-
from mcp_compressor.main import app
581-
582655
result = runner.invoke(app, ["-V"])
583656
assert result.exit_code == 0
584657
expected_version = importlib.metadata.version("mcp-compressor")

uv.lock

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)