Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions airbyte/mcp/_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
"""ASGI middleware for extracting authentication headers in MCP HTTP/SSE modes."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

from airbyte.mcp._request_context import (
CLOUD_CLIENT_ID_CVAR,
CLOUD_CLIENT_SECRET_CVAR,
CLOUD_WORKSPACE_ID_CVAR,
)


if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, MutableMapping

logger = logging.getLogger(__name__)


class HeaderAuthMiddleware:
"""ASGI middleware that extracts Airbyte Cloud authentication from HTTP headers.

This middleware runs only in HTTP/SSE modes of the MCP server. It extracts
authentication values from HTTP headers and stores them in ContextVars for
the duration of the request.

Supported headers (case-insensitive):
- X-Airbyte-Cloud-Client-Id or Airbyte-Cloud-Client-Id
- X-Airbyte-Cloud-Client-Secret or Airbyte-Cloud-Client-Secret
- X-Airbyte-Cloud-Workspace-Id or Airbyte-Cloud-Workspace-Id
"""

def __init__(self, app: Callable) -> None:
"""Initialize the middleware.

Args:
app: The ASGI application to wrap
"""
self.app = app

async def __call__(
self,
scope: MutableMapping,
receive: Callable[[], Awaitable[MutableMapping]],
send: Callable[[MutableMapping], Awaitable[None]],
) -> None:
"""Process the ASGI request.

Args:
scope: ASGI scope dictionary
receive: ASGI receive callable
send: ASGI send callable
"""
if scope["type"] != "http":
await self.app(scope, receive, send)
return

headers = scope.get("headers", [])
header_dict = {name.decode().lower(): value.decode() for name, value in headers}

client_id = self._get_header_value(
header_dict, ["x-airbyte-cloud-client-id", "airbyte-cloud-client-id"]
)
client_secret = self._get_header_value(
header_dict, ["x-airbyte-cloud-client-secret", "airbyte-cloud-client-secret"]
)
workspace_id = self._get_header_value(
header_dict, ["x-airbyte-cloud-workspace-id", "airbyte-cloud-workspace-id"]
)

tokens = []
try:
if client_id:
token = CLOUD_CLIENT_ID_CVAR.set(client_id)
tokens.append((CLOUD_CLIENT_ID_CVAR, token))
logger.debug("Set cloud client ID from HTTP header")

if client_secret:
token = CLOUD_CLIENT_SECRET_CVAR.set(client_secret)
tokens.append((CLOUD_CLIENT_SECRET_CVAR, token))
logger.debug("Set cloud client secret from HTTP header")

if workspace_id:
token = CLOUD_WORKSPACE_ID_CVAR.set(workspace_id)
tokens.append((CLOUD_WORKSPACE_ID_CVAR, token))
logger.debug("Set cloud workspace ID from HTTP header")

await self.app(scope, receive, send)

finally:
for cvar, token in tokens:
cvar.reset(token)

def _get_header_value(self, header_dict: dict[str, str], header_names: list[str]) -> str | None:
"""Get a header value by trying multiple possible header names.

Args:
header_dict: Dictionary of lowercase header names to values
header_names: List of possible header names to try (lowercase)

Returns:
The header value if found, otherwise None
"""
for name in header_names:
if name in header_dict:
return header_dict[name]
return None
63 changes: 63 additions & 0 deletions airbyte/mcp/_request_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
"""Request context management for MCP server HTTP/SSE modes.

This module provides ContextVars for storing per-request authentication values
extracted from HTTP headers. These values are scoped to a single request and
do not pollute the global environment.
"""

from __future__ import annotations

from contextvars import ContextVar

from airbyte.cloud.auth import (
resolve_cloud_client_id,
resolve_cloud_client_secret,
resolve_cloud_workspace_id,
)
from airbyte.secrets import SecretString


CLOUD_CLIENT_ID_CVAR: ContextVar[str | SecretString | None] = ContextVar(
"cloud_client_id", default=None
)
CLOUD_CLIENT_SECRET_CVAR: ContextVar[str | SecretString | None] = ContextVar(
"cloud_client_secret", default=None
)
CLOUD_WORKSPACE_ID_CVAR: ContextVar[str | None] = ContextVar("cloud_workspace_id", default=None)


def get_effective_cloud_client_id() -> SecretString:
"""Get the effective cloud client ID from request context or environment.

Returns:
Client ID from HTTP headers (if set), otherwise from environment variables
"""
header_value = CLOUD_CLIENT_ID_CVAR.get()
if header_value is not None:
return SecretString(header_value)
return resolve_cloud_client_id()


def get_effective_cloud_client_secret() -> SecretString:
"""Get the effective cloud client secret from request context or environment.

Returns:
Client secret from HTTP headers (if set), otherwise from environment variables
"""
header_value = CLOUD_CLIENT_SECRET_CVAR.get()
if header_value is not None:
return SecretString(header_value)
return resolve_cloud_client_secret()


def get_effective_cloud_workspace_id() -> str:
"""Get the effective cloud workspace ID from request context or environment.

Returns:
Workspace ID from HTTP headers (if set), otherwise from environment variables
"""
header_value = CLOUD_WORKSPACE_ID_CVAR.get()
if header_value is not None:
return str(header_value)
return resolve_cloud_workspace_id()
35 changes: 24 additions & 11 deletions airbyte/mcp/cloud_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,15 @@
from pydantic import Field

from airbyte import cloud, get_destination, get_source
from airbyte.cloud.auth import (
resolve_cloud_api_url,
resolve_cloud_client_id,
resolve_cloud_client_secret,
resolve_cloud_workspace_id,
)
from airbyte.cloud.connections import CloudConnection
from airbyte.cloud.connectors import CloudDestination, CloudSource, CustomCloudSourceDefinition
from airbyte.cloud.workspaces import CloudWorkspace
from airbyte.destinations.util import get_noop_destination
from airbyte.mcp._request_context import (
get_effective_cloud_client_id,
get_effective_cloud_client_secret,
get_effective_cloud_workspace_id,
)
from airbyte.mcp._tool_utils import (
check_guid_created_in_session,
mcp_tool,
Expand All @@ -28,12 +27,26 @@


def _get_cloud_workspace() -> CloudWorkspace:
"""Get an authenticated CloudWorkspace using environment variables."""
"""Get an authenticated CloudWorkspace using HTTP headers or environment variables.

When running in HTTP/SSE mode with the HeaderAuthMiddleware, authentication values
are extracted from HTTP headers:
- X-Airbyte-Cloud-Client-Id or Airbyte-Cloud-Client-Id
- X-Airbyte-Cloud-Client-Secret or Airbyte-Cloud-Client-Secret
- X-Airbyte-Cloud-Workspace-Id or Airbyte-Cloud-Workspace-Id

If headers are not provided, falls back to environment variables:
- AIRBYTE_CLOUD_CLIENT_ID
- AIRBYTE_CLOUD_CLIENT_SECRET
- AIRBYTE_CLOUD_WORKSPACE_ID

Returns:
Authenticated CloudWorkspace instance
"""
return CloudWorkspace(
workspace_id=resolve_cloud_workspace_id(),
client_id=resolve_cloud_client_id(),
client_secret=resolve_cloud_client_secret(),
api_root=resolve_cloud_api_url(),
workspace_id=get_effective_cloud_workspace_id(),
client_id=get_effective_cloud_client_id(),
client_secret=get_effective_cloud_client_secret(),
)


Expand Down
12 changes: 10 additions & 2 deletions airbyte/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastmcp import FastMCP

from airbyte._util.meta import set_mcp_mode
from airbyte.mcp._middleware import HeaderAuthMiddleware
from airbyte.mcp._util import initialize_secrets
from airbyte.mcp.cloud_ops import register_cloud_ops_tools
from airbyte.mcp.connector_registry import register_connector_registry_tools
Expand All @@ -16,8 +17,15 @@
set_mcp_mode()
initialize_secrets()

app: FastMCP = FastMCP("airbyte-mcp")
"""The Airbyte MCP Server application instance."""
app: FastMCP = FastMCP(
"airbyte-mcp",
middleware=[HeaderAuthMiddleware], # type: ignore[list-item]
)
"""The Airbyte MCP Server application instance.

When running in HTTP/SSE mode, the HeaderAuthMiddleware extracts authentication
values from HTTP headers and makes them available via ContextVars.
"""

register_connector_registry_tools(app)
register_local_ops_tools(app)
Expand Down