Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from jose import JWTError, jwt

from src.config import settings
from src.database.control_plane_store import control_plane_store
from src.database.api_key_store import APIKeyStore
from src.database.user_store import UserStore
from src.pipelines.ingest import IngestPipeline
Expand Down Expand Up @@ -288,7 +289,7 @@ async def require_user(current_user: Optional[dict] = Depends(get_current_user))


# ═══════════════════════════════════════════════════════════════════════════
# Sliding-window rate limiter (in-process, per-key)
# Sliding-window rate limiter
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _SlidingWindowRateLimiter class and its instance _rate_limiter appear to be redundant now that rate limiting logic has been moved to ControlPlaneStore. Consider removing them and updating the associated tests to avoid maintaining dead code.

# ═══════════════════════════════════════════════════════════════════════════

class _SlidingWindowRateLimiter:
Expand Down Expand Up @@ -329,7 +330,11 @@ async def enforce_rate_limit(
) -> dict:
"""Raise 429 if the caller has exceeded their per-minute quota."""
identity = user.get("id", "anonymous")
allowed, remaining = await _rate_limiter.check(identity)
allowed, remaining = await control_plane_store.check_rate_limit(
identity,
max_requests=settings.rate_limit,
window_seconds=60,
)

request.state.rate_limit_remaining = remaining

Expand Down
31 changes: 14 additions & 17 deletions src/api/routes/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from src.config import settings
from src.config.analytics import analytics
from src.database.control_plane_store import control_plane_store

logger = logging.getLogger("xmem.api.admin")

Expand All @@ -44,7 +45,6 @@
# ═══════════════════════════════════════════════════════════════════════════

_admin_collection = None
_admin_sessions: Dict[str, Dict[str, Any]] = {} # token → {user, expires}


def _get_admin_collection():
Expand Down Expand Up @@ -87,15 +87,14 @@ def _verify_admin_token(request: Request) -> Dict[str, Any]:
if auth.startswith("Bearer "):
token = auth[7:]

if not token or token not in _admin_sessions:
if not token:
raise HTTPException(status_code=401, detail="Not authenticated")

session = _admin_sessions[token]
if datetime.now(timezone.utc) > session["expires"]:
del _admin_sessions[token]
user = control_plane_store.get_admin_session(token)
if not user:
raise HTTPException(status_code=401, detail="Session expired")

return session["user"]
return user


# ═══════════════════════════════════════════════════════════════════════════
Expand All @@ -114,11 +113,11 @@ async def admin_login(req: AdminLoginRequest):
raise HTTPException(status_code=401, detail="Invalid credentials")

# Generate session token
token = hashlib.sha256(f"{req.username}{time.time()}".encode()).hexdigest()
_admin_sessions[token] = {
"user": {"username": user["username"], "role": user.get("role", "admin")},
"expires": datetime.now(timezone.utc) + timedelta(hours=24),
}
session = control_plane_store.create_admin_session(
user={"username": user["username"], "role": user.get("role", "admin")},
ttl_seconds=24 * 60 * 60,
)
token = session["token"]

response = JSONResponse({"status": "ok", "token": token, "username": user["username"]})
response.set_cookie(
Expand All @@ -134,8 +133,8 @@ async def admin_login(req: AdminLoginRequest):
@router.post("/api/logout")
async def admin_logout(request: Request):
token = request.cookies.get("xmem_admin_token")
if token and token in _admin_sessions:
del _admin_sessions[token]
if token:
control_plane_store.delete_admin_session(token)
response = JSONResponse({"status": "ok"})
response.delete_cookie("xmem_admin_token")
return response
Expand Down Expand Up @@ -219,7 +218,7 @@ async def ws_live_logs(websocket: WebSocket):

# Validate auth token from query param
token = websocket.query_params.get("token", "")
if token not in _admin_sessions:
if not token or not control_plane_store.get_admin_session(token):
await websocket.close(code=4001, reason="Not authenticated")
return

Expand Down Expand Up @@ -313,7 +312,7 @@ async def _journal_stream():

if not line:
# journalctl exited — send error event and stop
yield f"event: error\ndata: journalctl process exited\n\n"
yield "event: error\ndata: journalctl process exited\n\n"
break

text = line.decode("utf-8", errors="replace").rstrip("\n")
Expand Down Expand Up @@ -385,8 +384,6 @@ async def analytics_summary(request: Request, user: dict = Depends(_verify_admin
now = datetime.now(timezone.utc)
last_24h = now - timedelta(hours=24)
last_7d = now - timedelta(days=7)
last_30d = now - timedelta(days=30)

# API call stats (last 24h)
api_calls_24h = list(collection.aggregate([
{"$match": {"event": "api_call", "ts": {"$gte": last_24h}}},
Expand Down
94 changes: 24 additions & 70 deletions src/api/routes/auth.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
"""Authentication routes for Google OAuth and JWT management."""

import secrets
import string
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
from fastapi import APIRouter, Depends, Form, HTTPException, status
from fastapi.responses import JSONResponse
from google.auth.transport import requests as google_requests
from google.oauth2 import id_token
from jose import JWTError, jwt
Expand All @@ -16,6 +14,7 @@
from src.config import settings
from src.database.user_store import UserStore
from src.database.api_key_store import APIKeyStore
from src.database.control_plane_store import control_plane_store

router = APIRouter(prefix="/auth", tags=["Authentication"])

Expand All @@ -24,86 +23,44 @@
api_key_store = APIKeyStore()

# ═══════════════════════════════════════════════════════════════════════════
# MCP OAuth Temp Token Store (in-memory with TTL)
# MCP OAuth Temp Token Store
# ═══════════════════════════════════════════════════════════════════════════
_mcp_temp_tokens: Dict[str, Dict[str, Any]] = {}
TEMP_TOKEN_PREFIX = "xm-temp-"
TEMP_TOKEN_TTL_MINUTES = 10
TEMP_TOKEN_LENGTH = 32


def _generate_mcp_temp_token() -> str:
"""Generate a temporary token for MCP OAuth flow."""
alphabet = string.ascii_letters + string.digits
random_part = "".join(secrets.choice(alphabet) for _ in range(TEMP_TOKEN_LENGTH))
return f"{TEMP_TOKEN_PREFIX}{random_part}"
MCP_TEMP_TOKEN_RECORD = "mcp_temp_token"
OAUTH_AUTH_CODE_RECORD = "oauth_auth_code"


def _create_mcp_temp_token(user_id: str) -> str:
def _create_mcp_temp_token(user_id: str) -> dict:
"""Create and store a temporary token for the user."""
token = _generate_mcp_temp_token()
expires_at = datetime.utcnow() + timedelta(minutes=TEMP_TOKEN_TTL_MINUTES)

_mcp_temp_tokens[token] = {
"user_id": user_id,
"created_at": datetime.utcnow(),
"expires_at": expires_at,
"exchanged": False,
}

return token
return control_plane_store.create_single_use_token(
record_type=MCP_TEMP_TOKEN_RECORD,
user_id=user_id,
prefix=TEMP_TOKEN_PREFIX,
ttl_seconds=TEMP_TOKEN_TTL_MINUTES * 60,
)


def _get_and_invalidate_mcp_token(token: str) -> Optional[str]:
"""Validate temp token and return user_id if valid, None otherwise."""
if token not in _mcp_temp_tokens:
return None

token_data = _mcp_temp_tokens[token]

# Check expiry
if datetime.utcnow() > token_data["expires_at"]:
del _mcp_temp_tokens[token]
return None

# Check if already exchanged
if token_data["exchanged"]:
return None

# Mark as exchanged and return user_id
user_id = token_data["user_id"]
del _mcp_temp_tokens[token] # Single-use token
return user_id
return control_plane_store.consume_single_use_token(MCP_TEMP_TOKEN_RECORD, token)


# ═══════════════════════════════════════════════════════════════════════════
# Standard OAuth 2.0 Store (for ChatGPT UI)
# ═══════════════════════════════════════════════════════════════════════════
_oauth_auth_codes: Dict[str, Dict[str, Any]] = {}

def _generate_auth_code(user_id: str) -> str:
"""Generate a standard OAuth 2.0 authorization code."""
alphabet = string.ascii_letters + string.digits
code = "".join(secrets.choice(alphabet) for _ in range(32))

_oauth_auth_codes[code] = {
"user_id": user_id,
"expires_at": datetime.utcnow() + timedelta(minutes=10)
}
return code
return control_plane_store.create_single_use_token(
record_type=OAUTH_AUTH_CODE_RECORD,
user_id=user_id,
prefix="",
ttl_seconds=10 * 60,
)["token"]

def _get_and_invalidate_auth_code(code: str) -> Optional[str]:
"""Validate auth code and return user_id if valid."""
if code not in _oauth_auth_codes:
return None

data = _oauth_auth_codes[code]
del _oauth_auth_codes[code] # Single-use

if datetime.utcnow() > data["expires_at"]:
return None

return data["user_id"]
return control_plane_store.consume_single_use_token(OAUTH_AUTH_CODE_RECORD, code)


# ═══════════════════════════════════════════════════════════════════════════
Expand Down Expand Up @@ -463,9 +420,9 @@ async def generate_mcp_temp_token(current_user: dict = Depends(require_user)):
temp_token = _create_mcp_temp_token(user_id)

return MCPTempTokenResponse(
temp_token=temp_token,
temp_token=temp_token["token"],
expires_in=TEMP_TOKEN_TTL_MINUTES * 60,
expires_at=_mcp_temp_tokens[temp_token]["expires_at"]
expires_at=temp_token["expires_at"],
)


Expand Down Expand Up @@ -535,9 +492,6 @@ async def oauth_approve(request: OAuthApproveRequest, current_user: dict = Depen
return OAuthApproveResponse(code=code)


from fastapi import Form
from fastapi.responses import JSONResponse

@router.post("/oauth/token")
async def oauth_token(
grant_type: str = Form(...),
Expand Down
Loading
Loading