diff --git a/src/north_mcp_python_sdk/__init__.py b/src/north_mcp_python_sdk/__init__.py index c47bab8..d7fab3e 100644 --- a/src/north_mcp_python_sdk/__init__.py +++ b/src/north_mcp_python_sdk/__init__.py @@ -8,7 +8,12 @@ from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware -from .auth import AuthContextMiddleware, NorthAuthBackend, on_auth_error +from .auth import ( + AuthContextMiddleware, + HeadersContextMiddleware, + NorthAuthBackend, + on_auth_error, +) def is_debug_mode() -> bool: @@ -66,5 +71,6 @@ def _add_middleware(self, app: Starlette) -> None: on_error=on_auth_error, ), Middleware(AuthContextMiddleware, debug=self._debug), + Middleware(HeadersContextMiddleware, debug=self._debug), ] app.user_middleware.extend(middleware) diff --git a/src/north_mcp_python_sdk/auth.py b/src/north_mcp_python_sdk/auth.py index 1d0762d..0113623 100644 --- a/src/north_mcp_python_sdk/auth.py +++ b/src/north_mcp_python_sdk/auth.py @@ -1,6 +1,7 @@ import base64 import contextvars import logging +from typing import Any import jwt from pydantic import BaseModel, Field, ValidationError @@ -35,6 +36,10 @@ def __init__( "north_auth_context", default=None ) +headers_context_var = contextvars.ContextVar[dict[str, Any] | None]( + "north_headers_context", default=None +) + def on_auth_error(request: HTTPConnection, exc: AuthenticationError) -> JSONResponse: return JSONResponse({"error": str(exc)}, status_code=401) @@ -48,6 +53,14 @@ def get_authenticated_user() -> AuthenticatedNorthUser: return user +def get_raw_headers() -> dict[str, Any]: + headers = headers_context_var.get() + if not headers: + raise Exception("headers not found in context") + + return headers + + class AuthContextMiddleware: """ Middleware that extracts the authenticated user from the request @@ -83,6 +96,32 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): auth_context_var.reset(token) +class HeadersContextMiddleware: + """ + Middleware that sets the request headers in a contextvar for easy access + throughout the request lifecycle. + """ + + def __init__(self, app: ASGIApp, debug: bool = False): + self.app = app + self.debug = debug + self.logger = logging.getLogger("NorthMCP.HeadersContext") + if debug: + self.logger.setLevel(logging.DEBUG) + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + if scope["type"] == "lifespan": + return await self.app(scope, receive, send) + + headers = dict(scope.get("headers", {})) + self.logger.debug("Setting request headers in context: %s", headers) + token = headers_context_var.set(headers) + try: + await self.app(scope, receive, send) + finally: + headers_context_var.reset(token) + + class NorthAuthBackend(AuthenticationBackend): """ Authentication backend that validates Bearer tokens.