|
7 | 7 | from typing import Any, cast |
8 | 8 |
|
9 | 9 | import pytest |
| 10 | +from fastmcp.client.auth.bearer import BearerAuth |
10 | 11 | from fastmcp.client.auth.oauth import ClientNotFoundError, OAuth |
11 | | -from fastmcp.client.transports import MCPConfigTransport, SSETransport, StdioTransport, StreamableHttpTransport |
| 12 | +from fastmcp.client.transports import SSETransport, StdioTransport, StreamableHttpTransport |
12 | 13 | from fastmcp.exceptions import McpError |
13 | 14 | from typer.testing import CliRunner |
14 | 15 |
|
|
19 | 20 | suppress_recoverable_oauth_traceback_logging, |
20 | 21 | ) |
21 | 22 | from mcp_compressor.main import ( |
| 23 | + _get_single_server_transport_from_mcp_config, |
22 | 24 | _get_sse_transport, |
23 | 25 | _get_stdio_transport, |
24 | 26 | _get_streamable_http_transport, |
@@ -120,6 +122,51 @@ def test_parse_single_server_mcp_config_rejects_multiple_servers() -> None: |
120 | 122 | # Tests for transport creation functions |
121 | 123 |
|
122 | 124 |
|
| 125 | +def test_get_single_server_transport_from_mcp_config_remote_defaults_to_oauth(monkeypatch: pytest.MonkeyPatch) -> None: |
| 126 | + config_json = '{"mcpServers": {"weather": {"url": "https://example.com/mcp"}}}' |
| 127 | + parsed = _parse_single_server_mcp_config([config_json]) |
| 128 | + assert parsed is not None |
| 129 | + config, _ = parsed |
| 130 | + |
| 131 | + token_storage = object() |
| 132 | + monkeypatch.setattr(main_module, "_build_token_storage", lambda: token_storage) |
| 133 | + |
| 134 | + transport, transport_type = _get_single_server_transport_from_mcp_config(config=config) |
| 135 | + |
| 136 | + assert transport_type == "http" |
| 137 | + assert isinstance(transport, StreamableHttpTransport) |
| 138 | + assert isinstance(transport.auth, OAuth) |
| 139 | + assert transport.auth.mcp_url == "https://example.com/mcp" |
| 140 | + assert transport.auth._token_storage is token_storage |
| 141 | + |
| 142 | + |
| 143 | +def test_get_single_server_transport_from_mcp_config_remote_preserves_explicit_auth() -> None: |
| 144 | + config_json = '{"mcpServers": {"weather": {"url": "https://example.com/mcp", "auth": "abc"}}}' |
| 145 | + parsed = _parse_single_server_mcp_config([config_json]) |
| 146 | + assert parsed is not None |
| 147 | + config, _ = parsed |
| 148 | + |
| 149 | + transport, transport_type = _get_single_server_transport_from_mcp_config(config=config) |
| 150 | + |
| 151 | + assert transport_type == "http" |
| 152 | + assert isinstance(transport, StreamableHttpTransport) |
| 153 | + assert isinstance(transport.auth, BearerAuth) |
| 154 | + assert transport.auth.token.get_secret_value() == "abc" |
| 155 | + |
| 156 | + |
| 157 | +def test_get_single_server_transport_from_mcp_config_sse_uses_only_config_timeout() -> None: |
| 158 | + config_json = '{"mcpServers": {"weather": {"url": "https://example.com/sse", "transport": "sse"}}}' |
| 159 | + parsed = _parse_single_server_mcp_config([config_json]) |
| 160 | + assert parsed is not None |
| 161 | + config, _ = parsed |
| 162 | + |
| 163 | + transport, transport_type = _get_single_server_transport_from_mcp_config(config=config) |
| 164 | + |
| 165 | + assert transport_type == "sse" |
| 166 | + assert isinstance(transport, SSETransport) |
| 167 | + assert transport.sse_read_timeout is None |
| 168 | + |
| 169 | + |
123 | 170 | def test_get_stdio_transport(tmp_path) -> None: |
124 | 171 | """Test that stdio transport is created with correct parameters.""" |
125 | 172 | transport = _get_stdio_transport( |
@@ -296,7 +343,34 @@ async def fake_async_main(**kwargs: Any) -> None: |
296 | 343 | assert async_main_called is False |
297 | 344 |
|
298 | 345 |
|
299 | | -async def test_server_uses_mcp_config_transport_for_single_server_json(monkeypatch: pytest.MonkeyPatch) -> None: |
| 346 | +@pytest.mark.parametrize( |
| 347 | + ("extra_args", "expected_option"), |
| 348 | + [ |
| 349 | + (["--cwd", "."], "--cwd"), |
| 350 | + (["--env", "FOO=bar"], "--env"), |
| 351 | + (["--header", "Authorization=Bearer abc"], "--header"), |
| 352 | + (["--timeout", "30"], "--timeout"), |
| 353 | + ], |
| 354 | +) |
| 355 | +def test_single_server_mcp_config_rejects_conflicting_transport_options( |
| 356 | + runner: CliRunner, monkeypatch: pytest.MonkeyPatch, extra_args: list[str], expected_option: str |
| 357 | +) -> None: |
| 358 | + config_json = '{"mcpServers": {"weather": {"url": "https://example.com/mcp"}}}' |
| 359 | + async_main_called = False |
| 360 | + |
| 361 | + async def fake_async_main(**kwargs: Any) -> None: |
| 362 | + nonlocal async_main_called |
| 363 | + async_main_called = True |
| 364 | + |
| 365 | + monkeypatch.setattr(main_module, "_async_main", fake_async_main) |
| 366 | + result = runner.invoke(app, [*extra_args, config_json]) |
| 367 | + |
| 368 | + assert result.exit_code != 0 |
| 369 | + assert expected_option in _strip_ansi(result.output) |
| 370 | + assert async_main_called is False |
| 371 | + |
| 372 | + |
| 373 | +async def test_server_uses_single_server_config_transport_directly(monkeypatch: pytest.MonkeyPatch) -> None: |
300 | 374 | config_json = '{"mcpServers": {"weather": {"command": "uvx", "args": ["mcp-weather"]}}}' |
301 | 375 | captured: dict[str, Any] = {} |
302 | 376 |
|
@@ -334,7 +408,9 @@ async def get_compression_stats(self) -> dict[str, int]: |
334 | 408 | ) as mcp: |
335 | 409 | assert mcp is fake_mcp |
336 | 410 |
|
337 | | - assert isinstance(captured["transport"], MCPConfigTransport) |
| 411 | + assert isinstance(captured["transport"], StdioTransport) |
| 412 | + assert captured["transport"].command == "uvx" |
| 413 | + assert captured["transport"].args == ["mcp-weather"] |
338 | 414 | assert captured["compressed_tools_kwargs"]["server_name"] == "weather" |
339 | 415 |
|
340 | 416 |
|
@@ -576,9 +652,6 @@ def test_version_flag(runner: CliRunner) -> None: |
576 | 652 |
|
577 | 653 | def test_version_short_flag(runner: CliRunner) -> None: |
578 | 654 | """-V should be an alias for --version.""" |
579 | | - |
580 | | - from mcp_compressor.main import app |
581 | | - |
582 | 655 | result = runner.invoke(app, ["-V"]) |
583 | 656 | assert result.exit_code == 0 |
584 | 657 | expected_version = importlib.metadata.version("mcp-compressor") |
|
0 commit comments