diff --git a/backend/session_manager.py b/backend/session_manager.py index 3c992c9c..a72149af 100644 --- a/backend/session_manager.py +++ b/backend/session_manager.py @@ -6,7 +6,7 @@ import os import uuid from dataclasses import dataclass, field -from datetime import datetime +from datetime import datetime, timedelta from pathlib import Path from typing import Any, Optional @@ -100,6 +100,8 @@ class AgentSession: is_processing: bool = False # True while a submission is being executed broadcaster: Any = None title: str | None = None + # Last time this session was accessed (for idle-unload logic) + last_access: datetime = field(default_factory=datetime.utcnow) # True once this session has been counted against the user's daily # Claude quota. Guards double-counting when the user re-selects an # Anthropic model mid-session. @@ -141,10 +143,21 @@ async def start(self) -> None: self.persistence_store = get_session_store() await self.persistence_store.init() await self.messaging_gateway.start() + # Start background cleanup task to unload long-idle sessions. + self._cleanup_task = asyncio.create_task(self._unload_inactive_sessions_loop()) async def close(self) -> None: """Flush and close shared background resources.""" await self._cleanup_all_sandboxes_on_close() + # Cancel cleanup task + if getattr(self, "_cleanup_task", None): + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + except Exception: + logger.exception("Cleanup task failed during shutdown") await self.messaging_gateway.close() if self.persistence_store is not None: await self.persistence_store.close() @@ -197,6 +210,40 @@ def _create_session_sync( logger.info("Session initialized in %.2fs", t1 - t0) return tool_router, session + def _make_reserved_session( + self, + *, + session_id: str, + user_id: str, + hf_username: str | None, + hf_token: str | None, + submission_queue: asyncio.Queue, + ) -> AgentSession: + """Create a placeholder session that reserves capacity under lock.""" + return AgentSession( + session_id=session_id, + session=None, + tool_router=None, + submission_queue=submission_queue, + user_id=user_id, + hf_username=hf_username, + hf_token=hf_token, + is_active=True, + ) + + async def _release_reserved_session_slot( + self, + session_id: str, + reserved_session: AgentSession | None = None, + ) -> None: + """Remove a reserved placeholder if it is still present.""" + async with self._lock: + current = self.sessions.get(session_id) + if current is None: + return + if current is reserved_session or getattr(current, "session", None) is None: + self.sessions.pop(session_id, None) + def _serialize_messages(self, session: Session) -> list[dict[str, Any]]: return [msg.model_dump(mode="json") for msg in session.context_manager.items] @@ -327,8 +374,17 @@ async def _start_agent_session( async with self._lock: existing = self.sessions.get(agent_session.session_id) if existing: - return existing - self.sessions[agent_session.session_id] = agent_session + # If an earlier coroutine reserved the slot with a placeholder + # AgentSession (session is None), replace it with the real + # agent_session. Otherwise return the existing live session. + if getattr(existing, "session", None) is None: + self.sessions[agent_session.session_id] = agent_session + else: + # Update access time for the existing live session + existing.last_access = datetime.utcnow() + return existing + else: + self.sessions[agent_session.session_id] = agent_session task = asyncio.create_task( self._run_session( @@ -339,6 +395,8 @@ async def _start_agent_session( ) ) agent_session.task = task + # mark last access time when the session has been started + agent_session.last_access = datetime.utcnow() return agent_session @staticmethod @@ -494,6 +552,90 @@ async def _cleanup_persisted_sandbox( last_err, ) + async def _unload_inactive_sessions_once(self) -> None: + """Run one idle-session sweep. + + Extracted so tests can exercise the cleanup behavior without waiting + for the background loop's sleep interval. + """ + now = datetime.utcnow() + to_unload: list[tuple[str, AgentSession]] = [] + async with self._lock: + for sid, agent_session in list(self.sessions.items()): + try: + if not agent_session.is_active: + continue + if getattr(agent_session, "is_processing", False): + continue + last = getattr( + agent_session, + "last_access", + agent_session.created_at, + ) + if now - last > INACTIVE_SESSION_IDLE_THRESHOLD: + # Mark inactive, but keep it resident until the + # snapshot has been persisted successfully. + agent_session.is_active = False + to_unload.append((sid, agent_session)) + except Exception as e: + logger.debug("Skipping unload check for %s: %s", sid, e) + + for sid, agent_session in to_unload: + try: + await self.persist_session_snapshot( + agent_session, + runtime_state=self._runtime_state(agent_session), + status="inactive", + ) + except Exception as e: + logger.warning( + "Failed to persist snapshot before unloading %s: %s", + sid, + e, + ) + # Keep the session in memory so the next cleanup cycle + # can retry persistence and so callers can still inspect + # the session state. + agent_session.is_active = True + continue + + removed = False + async with self._lock: + current = self.sessions.get(sid) + if current is agent_session: + self.sessions.pop(sid, None) + removed = True + + if not removed: + # Session was replaced or revived while we were persisting. + continue + + # Cancel the running task if present once the snapshot is safe. + try: + if agent_session.task: + agent_session.task.cancel() + try: + await asyncio.wait_for(agent_session.task, timeout=5) + except Exception: + pass + except Exception: + pass + + logger.info("Unloaded inactive session %s due to inactivity", sid) + + async def _unload_inactive_sessions_loop(self) -> None: + """Background task: unload sessions that have been idle for too long. + + Sessions idle for more than 24 hours are persisted and removed from + memory to free capacity. The loop runs periodically (every 10 minutes). + """ + try: + while True: + await asyncio.sleep(INACTIVE_SESSION_SWEEP_INTERVAL_SECONDS) + await self._unload_inactive_sessions_once() + except asyncio.CancelledError: + return + async def persist_session_snapshot( self, agent_session: AgentSession, @@ -554,131 +696,169 @@ async def ensure_session_loaded( preload_sandbox: bool = True, ) -> AgentSession | None: """Return a live runtime session, lazily restoring it from Mongo.""" - async with self._lock: - existing = self.sessions.get(session_id) - if existing: - if self._can_access_session(existing, user_id): - self._update_hf_identity( - existing, - hf_token=hf_token, - hf_username=hf_username, - ) - self._restart_cpu_preload_if_token_recovered( - existing, - preload_sandbox=preload_sandbox, - ) - return existing - return None + submission_queue: asyncio.Queue = asyncio.Queue() + event_queue: asyncio.Queue = asyncio.Queue() + reserved_session: AgentSession | None = None + should_release_reserved_slot = False - store = self._store() - loaded = await store.load_session(session_id) - if not loaded: - return None + try: + async with self._lock: + existing = self.sessions.get(session_id) + if existing: + if getattr(existing, "session", None) is None: + return None + if self._can_access_session(existing, user_id): + self._update_hf_identity( + existing, + hf_token=hf_token, + hf_username=hf_username, + ) + self._restart_cpu_preload_if_token_recovered( + existing, + preload_sandbox=preload_sandbox, + ) + existing.last_access = datetime.utcnow() + return existing + return None - async with self._lock: - existing = self.sessions.get(session_id) - if existing: - if self._can_access_session(existing, user_id): - self._update_hf_identity( - existing, - hf_token=hf_token, + active_count = self.active_session_count + if active_count >= MAX_SESSIONS: + logger.warning( + "Cannot restore session %s: server at capacity (%d/%d)", + session_id, + active_count, + MAX_SESSIONS, + ) + return None + + reserved_session = self._make_reserved_session( + session_id=session_id, + user_id=user_id, hf_username=hf_username, + hf_token=hf_token, + submission_queue=submission_queue, ) - self._restart_cpu_preload_if_token_recovered( - existing, - preload_sandbox=preload_sandbox, - ) - return existing - return None + self.sessions[session_id] = reserved_session + should_release_reserved_slot = True - meta = loaded.get("metadata") or {} - owner = str(meta.get("user_id") or "") - if user_id != "dev" and owner != "dev" and owner != user_id: - return None + store = self._store() + loaded = await store.load_session(session_id) + if not loaded: + return None - await self._cleanup_persisted_sandbox( - session_id, - meta, - hf_token=hf_token, - ) + async with self._lock: + existing = self.sessions.get(session_id) + if existing is not reserved_session: + if existing and getattr(existing, "session", None) is not None: + if self._can_access_session(existing, user_id): + self._update_hf_identity( + existing, + hf_token=hf_token, + hf_username=hf_username, + ) + self._restart_cpu_preload_if_token_recovered( + existing, + preload_sandbox=preload_sandbox, + ) + existing.last_access = datetime.utcnow() + should_release_reserved_slot = False + return existing + return None - from litellm import Message + meta = loaded.get("metadata") or {} + owner = str(meta.get("user_id") or "") + if user_id != "dev" and owner != "dev" and owner != user_id: + return None - model = meta.get("model") or self.config.model_name - event_queue: asyncio.Queue = asyncio.Queue() - submission_queue: asyncio.Queue = asyncio.Queue() - tool_router, session = await asyncio.to_thread( - self._create_session_sync, - session_id=session_id, - user_id=owner or user_id, - hf_username=hf_username, - hf_token=hf_token, - model=model, - event_queue=event_queue, - notification_destinations=meta.get("notification_destinations") or [], - ) - - restored_messages: list[Message] = [] - for raw in loaded.get("messages") or []: - if not isinstance(raw, dict) or raw.get("role") == "system": - continue - try: - restored_messages.append(Message.model_validate(raw)) - except Exception as e: - logger.warning("Dropping malformed restored message: %s", e) - if restored_messages: - # Keep the freshly-rendered system prompt, then attach the durable - # non-system context so tools/date/user context stay current. - session.context_manager.items = [ - session.context_manager.items[0], - *restored_messages, - ] - - self._restore_pending_approval(session, meta.get("pending_approval") or []) - session.turn_count = int(meta.get("turn_count") or 0) - session.auto_approval_enabled = bool(meta.get("auto_approval_enabled", False)) - raw_cap = meta.get("auto_approval_cost_cap_usd") - session.auto_approval_cost_cap_usd = ( - float(raw_cap) if isinstance(raw_cap, int | float) else None - ) - session.auto_approval_estimated_spend_usd = float( - meta.get("auto_approval_estimated_spend_usd") or 0.0 - ) + await self._cleanup_persisted_sandbox( + session_id, + meta, + hf_token=hf_token, + ) - created_at = meta.get("created_at") - if not isinstance(created_at, datetime): - created_at = datetime.utcnow() + from litellm import Message - agent_session = AgentSession( - session_id=session_id, - session=session, - tool_router=tool_router, - submission_queue=submission_queue, - user_id=owner or user_id, - hf_username=hf_username, - hf_token=hf_token, - created_at=created_at, - is_active=True, - is_processing=False, - claude_counted=bool(meta.get("claude_counted")), - title=meta.get("title"), - ) - started = await self._start_agent_session( - agent_session=agent_session, - event_queue=event_queue, - tool_router=tool_router, - ) - if started is not agent_session: - self._update_hf_identity( - started, + model = meta.get("model") or self.config.model_name + tool_router, session = await asyncio.to_thread( + self._create_session_sync, + session_id=session_id, + user_id=owner or user_id, + hf_username=hf_username, hf_token=hf_token, + model=model, + event_queue=event_queue, + notification_destinations=meta.get("notification_destinations") or [], + ) + + restored_messages: list[Message] = [] + for raw in loaded.get("messages") or []: + if not isinstance(raw, dict) or raw.get("role") == "system": + continue + try: + restored_messages.append(Message.model_validate(raw)) + except Exception as e: + logger.warning("Dropping malformed restored message: %s", e) + if restored_messages: + # Keep the freshly-rendered system prompt, then attach the durable + # non-system context so tools/date/user context stay current. + session.context_manager.items = [ + session.context_manager.items[0], + *restored_messages, + ] + + self._restore_pending_approval(session, meta.get("pending_approval") or []) + session.turn_count = int(meta.get("turn_count") or 0) + session.auto_approval_enabled = bool( + meta.get("auto_approval_enabled", False) + ) + raw_cap = meta.get("auto_approval_cost_cap_usd") + session.auto_approval_cost_cap_usd = ( + float(raw_cap) if isinstance(raw_cap, int | float) else None + ) + session.auto_approval_estimated_spend_usd = float( + meta.get("auto_approval_estimated_spend_usd") or 0.0 + ) + + created_at = meta.get("created_at") + if not isinstance(created_at, datetime): + created_at = datetime.utcnow() + + agent_session = AgentSession( + session_id=session_id, + session=session, + tool_router=tool_router, + submission_queue=submission_queue, + user_id=owner or user_id, hf_username=hf_username, + hf_token=hf_token, + created_at=created_at, + is_active=True, + is_processing=False, + claude_counted=bool(meta.get("claude_counted")), + title=meta.get("title"), ) - return started - if preload_sandbox: - self._start_cpu_sandbox_preload(agent_session) - logger.info("Restored session %s for user %s", session_id, owner or user_id) - return agent_session + started = await self._start_agent_session( + agent_session=agent_session, + event_queue=event_queue, + tool_router=tool_router, + ) + if started is not agent_session: + self._update_hf_identity( + started, + hf_token=hf_token, + hf_username=hf_username, + ) + started.last_access = datetime.utcnow() + should_release_reserved_slot = False + return started + if preload_sandbox: + self._start_cpu_sandbox_preload(agent_session) + logger.info("Restored session %s for user %s", session_id, owner or user_id) + should_release_reserved_slot = False + return agent_session + finally: + if should_release_reserved_slot: + await self._release_reserved_session_slot(session_id, reserved_session) async def create_session( self, @@ -706,7 +886,11 @@ async def create_session( SessionCapacityError: If the server or user has reached the maximum number of concurrent sessions. """ - # ── Capacity checks ────────────────────────────────────────── + # ── Capacity checks & reservation ─────────────────────────── + # Create lightweight queues up-front (non-blocking). + submission_queue: asyncio.Queue = asyncio.Queue() + event_queue: asyncio.Queue = asyncio.Queue() + async with self._lock: active_count = self.active_session_count if active_count >= MAX_SESSIONS: @@ -724,11 +908,22 @@ async def create_session( error_type="per_user", ) - session_id = str(uuid.uuid4()) - - # Create queues for this session - submission_queue: asyncio.Queue = asyncio.Queue() - event_queue: asyncio.Queue = asyncio.Queue() + session_id = str(uuid.uuid4()) + # Reserve the slot with a placeholder AgentSession so concurrent + # creators cannot exceed MAX_SESSIONS. The placeholder has no + # real Session/tool_router yet (session is None) but counts + # towards `active_session_count` because `is_active` is True. + placeholder = AgentSession( + session_id=session_id, + session=None, + tool_router=None, + submission_queue=submission_queue, + user_id=user_id, + hf_username=hf_username, + hf_token=hf_token, + is_active=True, + ) + self.sessions[session_id] = placeholder # Run blocking constructors in a thread to keep the event loop responsive. tool_router, session = await asyncio.to_thread( @@ -741,7 +936,8 @@ async def create_session( event_queue=event_queue, ) - # Create wrapper + # Create wrapper with the real session resources and replace the + # placeholder in _start_agent_session. agent_session = AgentSession( session_id=session_id, session=session, diff --git a/tests/unit/test_session_capacity.py b/tests/unit/test_session_capacity.py new file mode 100644 index 00000000..16a81573 --- /dev/null +++ b/tests/unit/test_session_capacity.py @@ -0,0 +1,245 @@ +import asyncio +import importlib +import sys +from datetime import datetime, timedelta +from pathlib import Path +from types import SimpleNamespace + +import pytest +import types + +_BACKEND_DIR = Path(__file__).resolve().parent.parent.parent / "backend" + + +@pytest.fixture +def session_manager_module(monkeypatch): + """Import backend.session_manager with temporary dependency stubs. + + The stubs are inserted with monkeypatch so they are restored after the test, + and the imported session_manager module is removed from sys.modules to avoid + leaking the stubbed import state into other tests. + """ + with monkeypatch.context() as m: + m.syspath_prepend(str(_BACKEND_DIR)) + + litellm_stub = types.ModuleType("litellm") + + class _DummyMessage: + @staticmethod + def model_validate(raw): + return SimpleNamespace(**raw) + + setattr(litellm_stub, "Message", _DummyMessage) + m.setitem(sys.modules, "litellm", litellm_stub) + + fastmcp_stub = types.ModuleType("fastmcp") + + class _DummyClient: + pass + + setattr(fastmcp_stub, "Client", _DummyClient) + m.setitem(sys.modules, "fastmcp", fastmcp_stub) + + m.setitem(sys.modules, "thefuzz", types.ModuleType("thefuzz")) + m.setitem( + sys.modules, + "huggingface_hub", + types.ModuleType("huggingface_hub"), + ) + + module = importlib.import_module("session_manager") + yield module + + sys.modules.pop("session_manager", None) + + +@pytest.mark.asyncio +async def test_restore_denied_when_at_capacity(session_manager_module, caplog): + manager = session_manager_module.SessionManager() + for i in range(session_manager_module.MAX_SESSIONS): + manager.sessions[str(i)] = session_manager_module.AgentSession( + session_id=str(i), + session=object(), + tool_router=None, + submission_queue=asyncio.Queue(), + is_active=True, + ) + + class DummyStore: + enabled = True + + async def load_session(self, sid): + return {"metadata": {"user_id": "test", "model": "gpt"}, "messages": []} + + manager.persistence_store = DummyStore() + + caplog.set_level("WARNING") + res = await manager.ensure_session_loaded("restored", user_id="test") + assert res is None + assert any("Cannot restore session" in rec.message for rec in caplog.records) + + +@pytest.mark.asyncio +async def test_restore_allowed_under_capacity(session_manager_module, monkeypatch): + manager = session_manager_module.SessionManager() + manager.sessions.clear() + + class DummyStore: + enabled = True + + async def load_session(self, sid): + return {"metadata": {"user_id": "test", "model": "gpt"}, "messages": []} + + manager.persistence_store = DummyStore() + + def fake_create_session_sync( + *, session_id, user_id, hf_username, hf_token, model, event_queue, notification_destinations + ): + class SimpleSession: + def __init__(self): + self.context_manager = SimpleNamespace(items=[object()]) + self.pending_approval = None + self.turn_count = 0 + self.notification_destinations = [] + self.config = SimpleNamespace(model_name=model) + + return None, SimpleSession() + + monkeypatch.setattr(manager, "_create_session_sync", fake_create_session_sync) + + async def fake_start_agent_session(*, agent_session, event_queue, tool_router): + async with manager._lock: + manager.sessions[agent_session.session_id] = agent_session + return agent_session + + monkeypatch.setattr(manager, "_start_agent_session", fake_start_agent_session) + + res = await manager.ensure_session_loaded("restored", user_id="test") + assert res is not None + assert res.session is not None + + +@pytest.mark.asyncio +async def test_restore_rolls_back_placeholder_on_load_failure(session_manager_module): + manager = session_manager_module.SessionManager() + + class FailingStore: + enabled = True + + async def load_session(self, sid): + raise RuntimeError("load failed") + + manager.persistence_store = FailingStore() + + with pytest.raises(RuntimeError, match="load failed"): + await manager.ensure_session_loaded("restored", user_id="test") + + assert manager.sessions == {} + + +@pytest.mark.asyncio +async def test_create_session_rolls_back_placeholder_on_failure( + session_manager_module, monkeypatch +): + manager = session_manager_module.SessionManager() + + def fake_create_session_sync(**kwargs): + raise RuntimeError("boom") + + monkeypatch.setattr(manager, "_create_session_sync", fake_create_session_sync) + + with pytest.raises(RuntimeError, match="boom"): + await manager.create_session(user_id="u1") + + assert manager.sessions == {} + + +@pytest.mark.asyncio +async def test_unload_inactive_session_persists_and_removes(session_manager_module): + manager = session_manager_module.SessionManager() + persisted = [] + + async def fake_persist_session_snapshot(agent_session, **kwargs): + persisted.append((agent_session.session_id, kwargs)) + + manager.persist_session_snapshot = fake_persist_session_snapshot + + stale_session = session_manager_module.AgentSession( + session_id="stale", + session=SimpleNamespace( + pending_approval=None, + notification_destinations=[], + turn_count=0, + config=SimpleNamespace(model_name="gpt"), + ), + tool_router=None, + submission_queue=asyncio.Queue(), + is_active=True, + last_access=datetime.utcnow() + - session_manager_module.INACTIVE_SESSION_IDLE_THRESHOLD + - timedelta(seconds=1), + ) + manager.sessions["stale"] = stale_session + + await manager._unload_inactive_sessions_once() + + assert "stale" not in manager.sessions + assert persisted == [ + ( + "stale", + {"runtime_state": "idle", "status": "inactive"}, + ) + ] + + +@pytest.mark.asyncio +async def test_unload_inactive_sessions_skips_active_and_processing( + session_manager_module, +): + manager = session_manager_module.SessionManager() + persisted = [] + + async def fake_persist_session_snapshot(agent_session, **kwargs): + persisted.append(agent_session.session_id) + + manager.persist_session_snapshot = fake_persist_session_snapshot + + active_session = session_manager_module.AgentSession( + session_id="active", + session=SimpleNamespace( + pending_approval=None, + notification_destinations=[], + turn_count=0, + config=SimpleNamespace(model_name="gpt"), + ), + tool_router=None, + submission_queue=asyncio.Queue(), + is_active=True, + last_access=datetime.utcnow() + - session_manager_module.INACTIVE_SESSION_IDLE_THRESHOLD + - timedelta(seconds=1), + ) + processing_session = session_manager_module.AgentSession( + session_id="processing", + session=SimpleNamespace( + pending_approval=None, + notification_destinations=[], + turn_count=0, + config=SimpleNamespace(model_name="gpt"), + ), + tool_router=None, + submission_queue=asyncio.Queue(), + is_active=True, + is_processing=True, + last_access=datetime.utcnow() + - session_manager_module.INACTIVE_SESSION_IDLE_THRESHOLD + - timedelta(seconds=1), + ) + manager.sessions["active"] = active_session + manager.sessions["processing"] = processing_session + + await manager._unload_inactive_sessions_once() + + assert "active" in manager.sessions + assert "processing" in manager.sessions + assert persisted == []