diff --git a/airbyte/mcp/_middleware.py b/airbyte/mcp/_middleware.py
new file mode 100644
index 000000000..fe882293e
--- /dev/null
+++ b/airbyte/mcp/_middleware.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+"""ASGI middleware for extracting authentication headers in MCP HTTP/SSE modes."""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+from airbyte.mcp._request_context import (
+    CLOUD_CLIENT_ID_CVAR,
+    CLOUD_CLIENT_SECRET_CVAR,
+    CLOUD_WORKSPACE_ID_CVAR,
+)
+
+
+if TYPE_CHECKING:
+    from collections.abc import Awaitable, Callable, MutableMapping
+
+logger = logging.getLogger(__name__)
+
+
+class HeaderAuthMiddleware:
+    """ASGI middleware that extracts Airbyte Cloud authentication from HTTP headers.
+
+    This middleware runs only in HTTP/SSE modes of the MCP server. It extracts
+    authentication values from HTTP headers and stores them in ContextVars for
+    the duration of the request.
+
+    Supported headers (case-insensitive):
+    - X-Airbyte-Cloud-Client-Id or Airbyte-Cloud-Client-Id
+    - X-Airbyte-Cloud-Client-Secret or Airbyte-Cloud-Client-Secret
+    - X-Airbyte-Cloud-Workspace-Id or Airbyte-Cloud-Workspace-Id
+
+    NOTE(review): ContextVar values set here are only visible to code that runs
+    inside the context of this ASGI call. Confirm that the MCP transport executes
+    tool handlers within the request's context rather than in a separate
+    long-lived session task.
+    """
+
+    def __init__(self, app: Callable) -> None:
+        """Initialize the middleware.
+
+        Args:
+            app: The ASGI application to wrap
+        """
+        self.app = app
+
+    async def __call__(
+        self,
+        scope: MutableMapping,
+        receive: Callable[[], Awaitable[MutableMapping]],
+        send: Callable[[MutableMapping], Awaitable[None]],
+    ) -> None:
+        """Process the ASGI request.
+
+        Non-HTTP scopes are passed through untouched. For HTTP scopes, any
+        recognized auth headers are stored in ContextVars while the wrapped app
+        runs, and the previous values are restored afterwards.
+
+        Args:
+            scope: ASGI scope dictionary
+            receive: ASGI receive callable
+            send: ASGI send callable
+        """
+        if scope["type"] != "http":
+            await self.app(scope, receive, send)
+            return
+
+        # ASGI delivers header names and values as raw bytes. Latin-1 maps every
+        # byte to a code point, so a malformed header can never raise
+        # UnicodeDecodeError and abort the request before the app runs.
+        header_dict = {
+            name.decode("latin-1").lower(): value.decode("latin-1")
+            for name, value in scope.get("headers", [])
+        }
+
+        client_id = self._get_header_value(
+            header_dict, ["x-airbyte-cloud-client-id", "airbyte-cloud-client-id"]
+        )
+        client_secret = self._get_header_value(
+            header_dict, ["x-airbyte-cloud-client-secret", "airbyte-cloud-client-secret"]
+        )
+        workspace_id = self._get_header_value(
+            header_dict, ["x-airbyte-cloud-workspace-id", "airbyte-cloud-workspace-id"]
+        )
+
+        tokens = []
+        try:
+            if client_id:
+                token = CLOUD_CLIENT_ID_CVAR.set(client_id)
+                tokens.append((CLOUD_CLIENT_ID_CVAR, token))
+                logger.debug("Set cloud client ID from HTTP header")
+
+            if client_secret:
+                token = CLOUD_CLIENT_SECRET_CVAR.set(client_secret)
+                tokens.append((CLOUD_CLIENT_SECRET_CVAR, token))
+                logger.debug("Set cloud client secret from HTTP header")
+
+            if workspace_id:
+                token = CLOUD_WORKSPACE_ID_CVAR.set(workspace_id)
+                tokens.append((CLOUD_WORKSPACE_ID_CVAR, token))
+                logger.debug("Set cloud workspace ID from HTTP header")
+
+            await self.app(scope, receive, send)
+
+        finally:
+            # Restore the previous values once the wrapped app has finished.
+            for cvar, token in tokens:
+                cvar.reset(token)
+
+    def _get_header_value(self, header_dict: dict[str, str], header_names: list[str]) -> str | None:
+        """Return the first non-empty value found under the given header names.
+
+        Args:
+            header_dict: Dictionary of lowercase header names to values
+            header_names: List of possible header names to try (lowercase), in priority order
+
+        Returns:
+            The first non-empty header value, or None when no candidate has one
+        """
+        for name in header_names:
+            value = header_dict.get(name)
+            # A present-but-empty header must not shadow a populated alternate name.
+            if value:
+                return value
+        return None
diff --git a/airbyte/mcp/_request_context.py b/airbyte/mcp/_request_context.py
new file mode 100644
index 000000000..dabfc59b2
--- /dev/null
+++ b/airbyte/mcp/_request_context.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+"""Request context management for MCP server HTTP/SSE modes.
+
+This module provides ContextVars for storing per-request authentication values
+extracted from HTTP headers. These values are scoped to a single request and
+do not pollute the global environment.
+"""
+
+from __future__ import annotations
+
+from contextvars import ContextVar
+
+from airbyte.cloud.auth import (
+    resolve_cloud_client_id,
+    resolve_cloud_client_secret,
+    resolve_cloud_workspace_id,
+)
+from airbyte.secrets import SecretString
+
+
+CLOUD_CLIENT_ID_CVAR: ContextVar[str | SecretString | None] = ContextVar(
+    "cloud_client_id", default=None
+)
+CLOUD_CLIENT_SECRET_CVAR: ContextVar[str | SecretString | None] = ContextVar(
+    "cloud_client_secret", default=None
+)
+CLOUD_WORKSPACE_ID_CVAR: ContextVar[str | None] = ContextVar("cloud_workspace_id", default=None)
+
+
+def get_effective_cloud_client_id() -> SecretString:
+    """Get the effective cloud client ID from request context or environment.
+
+    Returns:
+        Client ID from HTTP headers (if set), otherwise from environment variables
+    """
+    header_value = CLOUD_CLIENT_ID_CVAR.get()
+    if header_value is not None:
+        return SecretString(header_value)
+    return resolve_cloud_client_id()
+
+
+def get_effective_cloud_client_secret() -> SecretString:
+    """Get the effective cloud client secret from request context or environment.
+
+    Returns:
+        Client secret from HTTP headers (if set), otherwise from environment variables
+    """
+    header_value = CLOUD_CLIENT_SECRET_CVAR.get()
+    if header_value is not None:
+        return SecretString(header_value)
+    return resolve_cloud_client_secret()
+
+
+def get_effective_cloud_workspace_id() -> str:
+    """Get the effective cloud workspace ID from request context or environment.
+
+    Returns:
+        Workspace ID from HTTP headers (if set), otherwise from environment variables
+    """
+    header_value = CLOUD_WORKSPACE_ID_CVAR.get()
+    if header_value is not None:
+        return str(header_value)
+    return resolve_cloud_workspace_id()
diff --git a/airbyte/mcp/cloud_ops.py b/airbyte/mcp/cloud_ops.py
index 305731006..2fdd4f4d2 100644
--- a/airbyte/mcp/cloud_ops.py
+++ b/airbyte/mcp/cloud_ops.py
@@ -8,16 +8,16 @@
 from pydantic import Field
 
 from airbyte import cloud, get_destination, get_source
-from airbyte.cloud.auth import (
-    resolve_cloud_api_url,
-    resolve_cloud_client_id,
-    resolve_cloud_client_secret,
-    resolve_cloud_workspace_id,
-)
+from airbyte.cloud.auth import resolve_cloud_api_url
 from airbyte.cloud.connections import CloudConnection
 from airbyte.cloud.connectors import CloudDestination, CloudSource, CustomCloudSourceDefinition
 from airbyte.cloud.workspaces import CloudWorkspace
 from airbyte.destinations.util import get_noop_destination
+from airbyte.mcp._request_context import (
+    get_effective_cloud_client_id,
+    get_effective_cloud_client_secret,
+    get_effective_cloud_workspace_id,
+)
 from airbyte.mcp._tool_utils import (
     check_guid_created_in_session,
     mcp_tool,
@@ -28,11 +28,29 @@
 
 
 def _get_cloud_workspace() -> CloudWorkspace:
-    """Get an authenticated CloudWorkspace using environment variables."""
+    """Get an authenticated CloudWorkspace using HTTP headers or environment variables.
+
+    When running in HTTP/SSE mode with the HeaderAuthMiddleware, authentication values
+    are extracted from HTTP headers:
+    - X-Airbyte-Cloud-Client-Id or Airbyte-Cloud-Client-Id
+    - X-Airbyte-Cloud-Client-Secret or Airbyte-Cloud-Client-Secret
+    - X-Airbyte-Cloud-Workspace-Id or Airbyte-Cloud-Workspace-Id
+
+    If headers are not provided, falls back to environment variables:
+    - AIRBYTE_CLOUD_CLIENT_ID
+    - AIRBYTE_CLOUD_CLIENT_SECRET
+    - AIRBYTE_CLOUD_WORKSPACE_ID
+
+    The API root cannot be set via headers; it is always the value returned by
+    `resolve_cloud_api_url()`.
+
+    Returns:
+        Authenticated CloudWorkspace instance
+    """
     return CloudWorkspace(
-        workspace_id=resolve_cloud_workspace_id(),
-        client_id=resolve_cloud_client_id(),
-        client_secret=resolve_cloud_client_secret(),
+        workspace_id=get_effective_cloud_workspace_id(),
+        client_id=get_effective_cloud_client_id(),
+        client_secret=get_effective_cloud_client_secret(),
         api_root=resolve_cloud_api_url(),
     )
 
diff --git a/airbyte/mcp/server.py b/airbyte/mcp/server.py
index 9d6ffda5b..c2138dd1c 100644
--- a/airbyte/mcp/server.py
+++ b/airbyte/mcp/server.py
@@ -7,6 +7,7 @@
 from fastmcp import FastMCP
 
 from airbyte._util.meta import set_mcp_mode
+from airbyte.mcp._middleware import HeaderAuthMiddleware
 from airbyte.mcp._util import initialize_secrets
 from airbyte.mcp.cloud_ops import register_cloud_ops_tools
 from airbyte.mcp.connector_registry import register_connector_registry_tools
@@ -16,8 +17,18 @@
 set_mcp_mode()
 initialize_secrets()
 
-app: FastMCP = FastMCP("airbyte-mcp")
-"""The Airbyte MCP Server application instance."""
+# NOTE(review): FastMCP's `middleware` argument is documented for MCP-level middleware.
+# Confirm it accepts a raw ASGI middleware class; otherwise attach HeaderAuthMiddleware
+# to the HTTP/ASGI app instead. The `type: ignore` below hints at a mismatch.
+app: FastMCP = FastMCP(
+    "airbyte-mcp",
+    middleware=[HeaderAuthMiddleware],  # type: ignore[list-item]
+)
+"""The Airbyte MCP Server application instance.
+
+When running in HTTP/SSE mode, the HeaderAuthMiddleware extracts authentication
+values from HTTP headers and makes them available via ContextVars.
+"""
 
 register_connector_registry_tools(app)
 register_local_ops_tools(app)