Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/open_stocks_mcp/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from open_stocks_mcp.brokers.schwab import SchwabBroker
from open_stocks_mcp.config import ServerConfig, load_config
from open_stocks_mcp.logging_config import logger, setup_logging
from open_stocks_mcp.server.broker_filter import install_broker_filter
from open_stocks_mcp.server.tool_execution_limits import install_tool_execution_limit
from open_stocks_mcp.server.tool_helpers import (
get_broker_status_data,
Expand Down Expand Up @@ -2299,6 +2300,9 @@ def main(
logger.info("Initializing broker authentication...")
asyncio.run(setup_brokers(username, password, config=config))

# Gate tool access to enabled brokers only (applied after broker setup)
install_broker_filter(server, config.brokers.enabled_brokers)

# Start server regardless of authentication status
try:
if transport == "stdio":
Expand Down
101 changes: 101 additions & 0 deletions src/open_stocks_mcp/server/broker_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Broker availability filter for MCP tool dispatch.

Wraps FastMCP's call_tool and list_tools to enforce ENABLED_BROKERS at
runtime: tools for disabled brokers are hidden from list_tools and return a
descriptive error when called.
"""

from __future__ import annotations

import json
from typing import Any

import mcp.types
from mcp.server.fastmcp import FastMCP

_WRAPPER_ATTR = "_broker_filter_installed"
_ENABLED_ATTR = "_broker_filter_enabled_brokers"

# Tools that operate on server state rather than a specific broker and should
# always remain accessible regardless of ENABLED_BROKERS.
_BROKER_AGNOSTIC_TOOLS: frozenset[str] = frozenset(
{
"list_tools",
"session_status",
"broker_status",
"list_brokers",
"rate_limit_status",
"metrics_summary",
"aggregated_portfolio",
"broker_comparison",
"health_check",
}
)


def _tool_broker(tool_name: str) -> str | None:
"""Return the broker name required by *tool_name*, or ``None`` if agnostic.

Mapping rules (evaluated in order):
1. Tools in ``_BROKER_AGNOSTIC_TOOLS`` → no broker required.
2. Tool name contains ``"schwab"`` → requires ``"schwab"``.
3. Everything else → requires ``"robinhood"``.
"""
if tool_name in _BROKER_AGNOSTIC_TOOLS:
return None
if "schwab" in tool_name:
return "schwab"
return "robinhood"


def install_broker_filter(mcp_server: FastMCP, enabled_brokers: list[str]) -> None:
"""Wrap *mcp_server* to enforce ``ENABLED_BROKERS`` at dispatch time.

* ``call_tool``: tools for a disabled broker return a structured error
instead of executing.
* ``list_tools``: tools for disabled brokers are omitted from the listing.

Idempotent — re-calling with a new *enabled_brokers* list updates the
active set without stacking additional wrappers.
"""
if getattr(mcp_server, _WRAPPER_ATTR, False):
setattr(mcp_server, _ENABLED_ATTR, list(enabled_brokers))
return

setattr(mcp_server, _ENABLED_ATTR, list(enabled_brokers))
original_call_tool = mcp_server.call_tool
original_list_tools = mcp_server.list_tools

async def _filtered_call_tool(tool_name: str, arguments: dict[str, Any]) -> Any:
active_enabled: list[str] = getattr(mcp_server, _ENABLED_ATTR, ["robinhood"])
required = _tool_broker(tool_name)
if required is not None and required not in active_enabled:
error_data = {
"status": "error",
"error": (
f"Tool '{tool_name}' requires the '{required}' broker, "
f"which is not enabled. "
f"Set ENABLED_BROKERS to include '{required}' and restart the server. "
f"Currently enabled: {active_enabled}."
),
}
return mcp.types.CallToolResult(
content=[
mcp.types.TextContent(type="text", text=json.dumps(error_data))
],
isError=True,
)
return await original_call_tool(tool_name, arguments)

async def _filtered_list_tools() -> list[mcp.types.Tool]:
active_enabled: list[str] = getattr(mcp_server, _ENABLED_ATTR, ["robinhood"])
all_tools: list[mcp.types.Tool] = await original_list_tools()
return [
t
for t in all_tools
if (req := _tool_broker(t.name)) is None or req in active_enabled
]

mcp_server.call_tool = _filtered_call_tool # type: ignore[assignment]
mcp_server.list_tools = _filtered_list_tools # type: ignore[method-assign]
setattr(mcp_server, _WRAPPER_ATTR, True)
218 changes: 218 additions & 0 deletions tests/unit/test_broker_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
"""Tests for broker availability filter (ENABLED_BROKERS enforcement)."""

from __future__ import annotations

import json
from typing import Any

import pytest
from mcp.server.fastmcp import FastMCP

from open_stocks_mcp.server.broker_filter import (
_BROKER_AGNOSTIC_TOOLS,
_tool_broker,
install_broker_filter,
)

# ---------------------------------------------------------------------------
# _tool_broker mapping
# ---------------------------------------------------------------------------


@pytest.mark.unit
@pytest.mark.journey_system
class TestToolBrokerMapping:
def test_agnostic_tools_return_none(self) -> None:
for name in _BROKER_AGNOSTIC_TOOLS:
assert _tool_broker(name) is None, f"{name} should be agnostic"

def test_schwab_tool_detected_by_name_prefix(self) -> None:
assert _tool_broker("schwab_quote") == "schwab"
assert _tool_broker("schwab_buy_stock_market") == "schwab"
assert _tool_broker("schwab_account_numbers") == "schwab"

def test_schwab_tool_detected_by_substring(self) -> None:
# Any tool whose name contains "schwab" maps to schwab broker
assert _tool_broker("get_schwab_portfolio") == "schwab"

def test_robinhood_tool_is_default(self) -> None:
assert _tool_broker("account_info") == "robinhood"
assert _tool_broker("portfolio") == "robinhood"
assert _tool_broker("stock_quote") == "robinhood"
assert _tool_broker("positions") == "robinhood"

def test_unknown_tool_defaults_to_robinhood(self) -> None:
assert _tool_broker("some_new_tool") == "robinhood"


# ---------------------------------------------------------------------------
# install_broker_filter — call_tool enforcement
# ---------------------------------------------------------------------------


def _make_server(*tool_names: str) -> FastMCP:
"""Create a FastMCP instance with zero-argument stub tools."""
server = FastMCP("test")
for name in tool_names:

async def _stub() -> dict[str, Any]:
return {"result": "ok"}

_stub.__name__ = name
server.tool()(_stub)
return server


def _is_blocked(result: Any) -> bool:
"""True when the broker filter returned an error CallToolResult."""
return bool(getattr(result, "isError", False))


def _is_allowed(result: Any) -> bool:
"""True when the broker filter passed the call through to the real tool."""
return not _is_blocked(result)


@pytest.mark.unit
@pytest.mark.journey_system
class TestInstallBrokerFilter:
@pytest.mark.asyncio
async def test_schwab_tool_blocked_when_robinhood_only(self) -> None:
server = _make_server("schwab_quote", "portfolio")
install_broker_filter(server, ["robinhood"])

result = await server.call_tool("schwab_quote", {})
assert _is_blocked(result)
payload = json.loads(result.content[0].text) # type: ignore[union-attr]
assert payload["status"] == "error"
assert "schwab" in payload["error"]

@pytest.mark.asyncio
async def test_robinhood_tool_allowed_when_robinhood_enabled(self) -> None:
server = _make_server("portfolio")
install_broker_filter(server, ["robinhood"])

result = await server.call_tool("portfolio", {})
assert _is_allowed(result)

@pytest.mark.asyncio
async def test_robinhood_tool_blocked_when_schwab_only(self) -> None:
server = _make_server("portfolio")
install_broker_filter(server, ["schwab"])

result = await server.call_tool("portfolio", {})
assert _is_blocked(result)
payload = json.loads(result.content[0].text) # type: ignore[union-attr]
assert "robinhood" in payload["error"]

@pytest.mark.asyncio
async def test_schwab_tool_allowed_when_schwab_enabled(self) -> None:
server = _make_server("schwab_quote")
install_broker_filter(server, ["schwab"])

result = await server.call_tool("schwab_quote", {})
assert _is_allowed(result)

@pytest.mark.asyncio
async def test_all_tools_allowed_when_both_enabled(self) -> None:
server = _make_server("schwab_quote", "portfolio")
install_broker_filter(server, ["robinhood", "schwab"])

assert _is_allowed(await server.call_tool("schwab_quote", {}))
assert _is_allowed(await server.call_tool("portfolio", {}))

@pytest.mark.asyncio
async def test_agnostic_tools_always_allowed(self) -> None:
server = _make_server("broker_status", "list_brokers")
install_broker_filter(server, []) # no brokers enabled

assert _is_allowed(await server.call_tool("broker_status", {}))
assert _is_allowed(await server.call_tool("list_brokers", {}))

@pytest.mark.asyncio
async def test_error_message_names_required_broker(self) -> None:
server = _make_server("schwab_orders")
install_broker_filter(server, ["robinhood"])

result = await server.call_tool("schwab_orders", {})
assert _is_blocked(result)
payload = json.loads(result.content[0].text) # type: ignore[union-attr]
assert "schwab" in payload["error"]
assert "ENABLED_BROKERS" in payload["error"]


# ---------------------------------------------------------------------------
# install_broker_filter — list_tools filtering
# ---------------------------------------------------------------------------


@pytest.mark.unit
@pytest.mark.journey_system
class TestBrokerFilterListTools:
@pytest.mark.asyncio
async def test_schwab_tools_hidden_when_robinhood_only(self) -> None:
server = _make_server("schwab_quote", "portfolio", "broker_status")
install_broker_filter(server, ["robinhood"])

tools = await server.list_tools()
names = {t.name for t in tools}
assert "schwab_quote" not in names
assert "portfolio" in names
assert "broker_status" in names

@pytest.mark.asyncio
async def test_robinhood_tools_hidden_when_schwab_only(self) -> None:
server = _make_server("schwab_quote", "portfolio", "list_brokers")
install_broker_filter(server, ["schwab"])

tools = await server.list_tools()
names = {t.name for t in tools}
assert "portfolio" not in names
assert "schwab_quote" in names
assert "list_brokers" in names

@pytest.mark.asyncio
async def test_all_tools_visible_when_both_enabled(self) -> None:
server = _make_server("schwab_quote", "portfolio")
install_broker_filter(server, ["robinhood", "schwab"])

tools = await server.list_tools()
names = {t.name for t in tools}
assert "schwab_quote" in names
assert "portfolio" in names


# ---------------------------------------------------------------------------
# Idempotency
# ---------------------------------------------------------------------------


@pytest.mark.unit
@pytest.mark.journey_system
class TestBrokerFilterIdempotency:
@pytest.mark.asyncio
async def test_reinstall_updates_enabled_list(self) -> None:
server = _make_server("schwab_quote")
install_broker_filter(server, ["robinhood"])

# Schwab blocked initially
result = await server.call_tool("schwab_quote", {})
assert _is_blocked(result)

# Re-install with Schwab enabled — no double-wrapping
install_broker_filter(server, ["robinhood", "schwab"])
result = await server.call_tool("schwab_quote", {})
assert _is_allowed(result)

@pytest.mark.asyncio
async def test_wrapper_not_stacked_on_reinstall(self) -> None:
server = _make_server("portfolio")
original_call = server.call_tool
install_broker_filter(server, ["robinhood"])
install_broker_filter(server, ["robinhood"])
# Should still be the same single wrapper, not a double-wrapped chain
assert server.call_tool is not original_call
# A second install doesn't replace the wrapper
first_wrapper = server.call_tool
install_broker_filter(server, ["robinhood", "schwab"])
assert server.call_tool is first_wrapper
Loading