From 18786d9ad63e6fc9a57f336e7d471278f0eea5fb Mon Sep 17 00:00:00 2001 From: Om Shrivastava Date: Wed, 20 May 2026 16:10:05 +0530 Subject: [PATCH 1/4] Track last access time and unload inactive sessions Added last_access attribute to track session activity and implemented a background task to unload inactive sessions after 24 hours of idleness. --- backend/session_manager.py | 128 ++++++++++++++++++++++++++++++++++--- 1 file changed, 118 insertions(+), 10 deletions(-) diff --git a/backend/session_manager.py b/backend/session_manager.py index 3c992c9c..fbf73eb3 100644 --- a/backend/session_manager.py +++ b/backend/session_manager.py @@ -6,7 +6,7 @@ import os import uuid from dataclasses import dataclass, field -from datetime import datetime +from datetime import datetime, timedelta from pathlib import Path from typing import Any, Optional @@ -100,6 +100,8 @@ class AgentSession: is_processing: bool = False # True while a submission is being executed broadcaster: Any = None title: str | None = None + # Last time this session was accessed (for idle-unload logic) + last_access: datetime = field(default_factory=datetime.utcnow) # True once this session has been counted against the user's daily # Claude quota. Guards double-counting when the user re-selects an # Anthropic model mid-session. @@ -141,10 +143,19 @@ async def start(self) -> None: self.persistence_store = get_session_store() await self.persistence_store.init() await self.messaging_gateway.start() + # Start background cleanup task to unload long-idle sessions. + self._cleanup_task = asyncio.create_task(self._unload_inactive_sessions_loop()) async def close(self) -> None: """Flush and close shared background resources.""" await self._cleanup_all_sandboxes_on_close() + # Cancel cleanup task + if getattr(self, "_cleanup_task", None): + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass await self.messaging_gateway.close() if self.persistence_store is not None: await self.persistence_store.close() @@ -327,8 +338,17 @@ async def _start_agent_session( async with self._lock: existing = self.sessions.get(agent_session.session_id) if existing: - return existing - self.sessions[agent_session.session_id] = agent_session + # If an earlier coroutine reserved the slot with a placeholder + # AgentSession (session is None), replace it with the real + # agent_session. Otherwise return the existing live session. + if getattr(existing, "session", None) is None: + self.sessions[agent_session.session_id] = agent_session + else: + # Update access time for the existing live session + existing.last_access = datetime.utcnow() + return existing + else: + self.sessions[agent_session.session_id] = agent_session task = asyncio.create_task( self._run_session( @@ -339,6 +359,8 @@ async def _start_agent_session( ) ) agent_session.task = task + # mark last access time when the session has been started + agent_session.last_access = datetime.utcnow() return agent_session @staticmethod @@ -494,6 +516,61 @@ async def _cleanup_persisted_sandbox( last_err, ) + async def _unload_inactive_sessions_loop(self) -> None: + """Background task: unload sessions that have been idle for too long. + + Sessions idle for more than 24 hours are persisted and removed from + memory to free capacity. The loop runs periodically (every 10 minutes). + """ + try: + while True: + await asyncio.sleep(600) # 10 minutes + now = datetime.utcnow() + to_unload: list[tuple[str, AgentSession]] = [] + async with self._lock: + for sid, agent_session in list(self.sessions.items()): + try: + if not agent_session.is_active: + continue + if getattr(agent_session, "is_processing", False): + continue + last = getattr(agent_session, "last_access", agent_session.created_at) + if now - last > timedelta(hours=24): + # Mark inactive and remove from registry under lock + agent_session.is_active = False + to_unload.append((sid, agent_session)) + del self.sessions[sid] + except Exception as e: + logger.debug("Skipping unload check for %s: %s", sid, e) + + for sid, agent_session in to_unload: + try: + await self.persist_session_snapshot( + agent_session, + runtime_state=self._runtime_state(agent_session), + status="inactive", + ) + except Exception as e: + logger.warning( + "Failed to persist snapshot before unloading %s: %s", + sid, + e, + ) + # Cancel the running task if present + try: + if agent_session.task: + agent_session.task.cancel() + try: + await asyncio.wait_for(agent_session.task, timeout=5) + except Exception: + pass + except Exception: + pass + + logger.info("Unloaded inactive session %s due to inactivity", sid) + except asyncio.CancelledError: + return + async def persist_session_snapshot( self, agent_session: AgentSession, @@ -567,9 +644,22 @@ async def ensure_session_loaded( existing, preload_sandbox=preload_sandbox, ) + existing.last_access = datetime.utcnow() return existing return None + # Check capacity before restoring from persistence + async with self._lock: + active_count = self.active_session_count + if active_count >= MAX_SESSIONS: + logger.warning( + "Cannot restore session %s: server at capacity (%d/%d)", + session_id, + active_count, + MAX_SESSIONS, + ) + return None + store = self._store() loaded = await store.load_session(session_id) if not loaded: @@ -588,6 +678,7 @@ async def ensure_session_loaded( existing, preload_sandbox=preload_sandbox, ) + existing.last_access = datetime.utcnow() return existing return None @@ -674,6 +765,7 @@ async def ensure_session_loaded( hf_token=hf_token, hf_username=hf_username, ) + started.last_access = datetime.utcnow() return started if preload_sandbox: self._start_cpu_sandbox_preload(agent_session) @@ -706,7 +798,11 @@ async def create_session( SessionCapacityError: If the server or user has reached the maximum number of concurrent sessions. """ - # ── Capacity checks ────────────────────────────────────────── + # ── Capacity checks & reservation ─────────────────────────── + # Create lightweight queues up-front (non-blocking). + submission_queue: asyncio.Queue = asyncio.Queue() + event_queue: asyncio.Queue = asyncio.Queue() + async with self._lock: active_count = self.active_session_count if active_count >= MAX_SESSIONS: @@ -724,11 +820,22 @@ async def create_session( error_type="per_user", ) - session_id = str(uuid.uuid4()) - - # Create queues for this session - submission_queue: asyncio.Queue = asyncio.Queue() - event_queue: asyncio.Queue = asyncio.Queue() + session_id = str(uuid.uuid4()) + # Reserve the slot with a placeholder AgentSession so concurrent + # creators cannot exceed MAX_SESSIONS. The placeholder has no + # real Session/tool_router yet (session is None) but counts + # towards `active_session_count` because `is_active` is True. + placeholder = AgentSession( + session_id=session_id, + session=None, + tool_router=None, + submission_queue=submission_queue, + user_id=user_id, + hf_username=hf_username, + hf_token=hf_token, + is_active=True, + ) + self.sessions[session_id] = placeholder # Run blocking constructors in a thread to keep the event loop responsive. tool_router, session = await asyncio.to_thread( @@ -741,7 +848,8 @@ async def create_session( event_queue=event_queue, ) - # Create wrapper + # Create wrapper with the real session resources and replace the + # placeholder in _start_agent_session. agent_session = AgentSession( session_id=session_id, session=session, From 635f57c7eb109d911bfc2ea45fdfa4f3ed7af0db Mon Sep 17 00:00:00 2001 From: Om Shrivastava Date: Wed, 20 May 2026 16:12:36 +0530 Subject: [PATCH 2/4] Created test_session_capacity --- tests/unit/test_session_capacity.py | 90 +++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 tests/unit/test_session_capacity.py diff --git a/tests/unit/test_session_capacity.py b/tests/unit/test_session_capacity.py new file mode 100644 index 00000000..41ca048a --- /dev/null +++ b/tests/unit/test_session_capacity.py @@ -0,0 +1,90 @@ +import asyncio +import sys +from pathlib import Path +from types import SimpleNamespace + +import pytest +import types + +# Prevent importing heavy third-party modules when importing the backend module. +for _mod in ("litellm", "fastmcp", "thefuzz", "huggingface_hub"): + if _mod not in sys.modules: + m = types.ModuleType(_mod) + # fastmcp is imported as `from fastmcp import Client` in some codepaths + if _mod == "fastmcp": + class _DummyClient: + pass + + setattr(m, "Client", _DummyClient) + sys.modules[_mod] = m + +_BACKEND_DIR = Path(__file__).resolve().parent.parent.parent / "backend" +if str(_BACKEND_DIR) not in sys.path: + sys.path.insert(0, str(_BACKEND_DIR)) + +from session_manager import SessionManager, AgentSession, MAX_SESSIONS + + +@pytest.mark.asyncio +async def test_restore_denied_when_at_capacity(caplog): + manager = SessionManager() + # Fill in-memory sessions up to MAX_SESSIONS + for i in range(MAX_SESSIONS): + manager.sessions[str(i)] = AgentSession( + session_id=str(i), + session=object(), + tool_router=None, + submission_queue=asyncio.Queue(), + is_active=True, + ) + + class DummyStore: + enabled = True + + async def load_session(self, sid): + return {"metadata": {"user_id": "test", "model": "gpt"}, "messages": []} + + manager.persistence_store = DummyStore() + + caplog.set_level("WARNING") + res = await manager.ensure_session_loaded("restored", user_id="test") + assert res is None + assert any("Cannot restore session" in rec.message for rec in caplog.records) + + +@pytest.mark.asyncio +async def test_restore_allowed_under_capacity(monkeypatch): + manager = SessionManager() + manager.sessions.clear() + + class DummyStore: + enabled = True + + async def load_session(self, sid): + return {"metadata": {"user_id": "test", "model": "gpt"}, "messages": []} + + manager.persistence_store = DummyStore() + + # Replace the heavy sync constructor with a lightweight fake + def fake_create_session_sync(*, session_id, user_id, hf_username, hf_token, model, event_queue, notification_destinations): + class SimpleSession: + def __init__(self): + self.context_manager = SimpleNamespace(items=[object()]) + self.pending_approval = None + self.turn_count = 0 + + return None, SimpleSession() + + monkeypatch.setattr(manager, "_create_session_sync", fake_create_session_sync) + + # Fake _start_agent_session to register the session without starting tasks + async def fake_start_agent_session(*, agent_session, event_queue, tool_router): + async with manager._lock: + manager.sessions[agent_session.session_id] = agent_session + return agent_session + + monkeypatch.setattr(manager, "_start_agent_session", fake_start_agent_session) + + res = await manager.ensure_session_loaded("restored", user_id="test") + assert res is not None + assert res.session is not None \ No newline at end of file From 6663d0425200af3ee4e5f3a408f60b250fadc498 Mon Sep 17 00:00:00 2001 From: Om Shrivastava Date: Wed, 20 May 2026 16:33:43 +0530 Subject: [PATCH 3/4] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- backend/session_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backend/session_manager.py b/backend/session_manager.py index fbf73eb3..5b63b23c 100644 --- a/backend/session_manager.py +++ b/backend/session_manager.py @@ -156,6 +156,8 @@ async def close(self) -> None: await self._cleanup_task except asyncio.CancelledError: pass + except Exception: + logger.exception("Cleanup task failed during shutdown") await self.messaging_gateway.close() if self.persistence_store is not None: await self.persistence_store.close() From 304c10091653932f7dceee2c9762a28f1d5794b3 Mon Sep 17 00:00:00 2001 From: Om Shrivastava Date: Wed, 20 May 2026 16:41:03 +0530 Subject: [PATCH 4/4] backend: enforce session caps and add idle unload tests --- backend/session_manager.py | 426 +++++++++++++++++----------- tests/unit/test_session_capacity.py | 205 +++++++++++-- 2 files changed, 436 insertions(+), 195 deletions(-) diff --git a/backend/session_manager.py b/backend/session_manager.py index 5b63b23c..a72149af 100644 --- a/backend/session_manager.py +++ b/backend/session_manager.py @@ -210,6 +210,40 @@ def _create_session_sync( logger.info("Session initialized in %.2fs", t1 - t0) return tool_router, session + def _make_reserved_session( + self, + *, + session_id: str, + user_id: str, + hf_username: str | None, + hf_token: str | None, + submission_queue: asyncio.Queue, + ) -> AgentSession: + """Create a placeholder session that reserves capacity under lock.""" + return AgentSession( + session_id=session_id, + session=None, + tool_router=None, + submission_queue=submission_queue, + user_id=user_id, + hf_username=hf_username, + hf_token=hf_token, + is_active=True, + ) + + async def _release_reserved_session_slot( + self, + session_id: str, + reserved_session: AgentSession | None = None, + ) -> None: + """Remove a reserved placeholder if it is still present.""" + async with self._lock: + current = self.sessions.get(session_id) + if current is None: + return + if current is reserved_session or getattr(current, "session", None) is None: + self.sessions.pop(session_id, None) + def _serialize_messages(self, session: Session) -> list[dict[str, Any]]: return [msg.model_dump(mode="json") for msg in session.context_manager.items] @@ -518,6 +552,77 @@ async def _cleanup_persisted_sandbox( last_err, ) + async def _unload_inactive_sessions_once(self) -> None: + """Run one idle-session sweep. + + Extracted so tests can exercise the cleanup behavior without waiting + for the background loop's sleep interval. + """ + now = datetime.utcnow() + to_unload: list[tuple[str, AgentSession]] = [] + async with self._lock: + for sid, agent_session in list(self.sessions.items()): + try: + if not agent_session.is_active: + continue + if getattr(agent_session, "is_processing", False): + continue + last = getattr( + agent_session, + "last_access", + agent_session.created_at, + ) + if now - last > INACTIVE_SESSION_IDLE_THRESHOLD: + # Mark inactive, but keep it resident until the + # snapshot has been persisted successfully. + agent_session.is_active = False + to_unload.append((sid, agent_session)) + except Exception as e: + logger.debug("Skipping unload check for %s: %s", sid, e) + + for sid, agent_session in to_unload: + try: + await self.persist_session_snapshot( + agent_session, + runtime_state=self._runtime_state(agent_session), + status="inactive", + ) + except Exception as e: + logger.warning( + "Failed to persist snapshot before unloading %s: %s", + sid, + e, + ) + # Keep the session in memory so the next cleanup cycle + # can retry persistence and so callers can still inspect + # the session state. + agent_session.is_active = True + continue + + removed = False + async with self._lock: + current = self.sessions.get(sid) + if current is agent_session: + self.sessions.pop(sid, None) + removed = True + + if not removed: + # Session was replaced or revived while we were persisting. + continue + + # Cancel the running task if present once the snapshot is safe. + try: + if agent_session.task: + agent_session.task.cancel() + try: + await asyncio.wait_for(agent_session.task, timeout=5) + except Exception: + pass + except Exception: + pass + + logger.info("Unloaded inactive session %s due to inactivity", sid) + async def _unload_inactive_sessions_loop(self) -> None: """Background task: unload sessions that have been idle for too long. @@ -526,50 +631,8 @@ async def _unload_inactive_sessions_loop(self) -> None: """ try: while True: - await asyncio.sleep(600) # 10 minutes - now = datetime.utcnow() - to_unload: list[tuple[str, AgentSession]] = [] - async with self._lock: - for sid, agent_session in list(self.sessions.items()): - try: - if not agent_session.is_active: - continue - if getattr(agent_session, "is_processing", False): - continue - last = getattr(agent_session, "last_access", agent_session.created_at) - if now - last > timedelta(hours=24): - # Mark inactive and remove from registry under lock - agent_session.is_active = False - to_unload.append((sid, agent_session)) - del self.sessions[sid] - except Exception as e: - logger.debug("Skipping unload check for %s: %s", sid, e) - - for sid, agent_session in to_unload: - try: - await self.persist_session_snapshot( - agent_session, - runtime_state=self._runtime_state(agent_session), - status="inactive", - ) - except Exception as e: - logger.warning( - "Failed to persist snapshot before unloading %s: %s", - sid, - e, - ) - # Cancel the running task if present - try: - if agent_session.task: - agent_session.task.cancel() - try: - await asyncio.wait_for(agent_session.task, timeout=5) - except Exception: - pass - except Exception: - pass - - logger.info("Unloaded inactive session %s due to inactivity", sid) + await asyncio.sleep(INACTIVE_SESSION_SWEEP_INTERVAL_SECONDS) + await self._unload_inactive_sessions_once() except asyncio.CancelledError: return @@ -633,146 +696,169 @@ async def ensure_session_loaded( preload_sandbox: bool = True, ) -> AgentSession | None: """Return a live runtime session, lazily restoring it from Mongo.""" - async with self._lock: - existing = self.sessions.get(session_id) - if existing: - if self._can_access_session(existing, user_id): - self._update_hf_identity( - existing, - hf_token=hf_token, - hf_username=hf_username, - ) - self._restart_cpu_preload_if_token_recovered( - existing, - preload_sandbox=preload_sandbox, - ) - existing.last_access = datetime.utcnow() - return existing - return None + submission_queue: asyncio.Queue = asyncio.Queue() + event_queue: asyncio.Queue = asyncio.Queue() + reserved_session: AgentSession | None = None + should_release_reserved_slot = False - # Check capacity before restoring from persistence - async with self._lock: - active_count = self.active_session_count - if active_count >= MAX_SESSIONS: - logger.warning( - "Cannot restore session %s: server at capacity (%d/%d)", - session_id, - active_count, - MAX_SESSIONS, - ) - return None + try: + async with self._lock: + existing = self.sessions.get(session_id) + if existing: + if getattr(existing, "session", None) is None: + return None + if self._can_access_session(existing, user_id): + self._update_hf_identity( + existing, + hf_token=hf_token, + hf_username=hf_username, + ) + self._restart_cpu_preload_if_token_recovered( + existing, + preload_sandbox=preload_sandbox, + ) + existing.last_access = datetime.utcnow() + return existing + return None - store = self._store() - loaded = await store.load_session(session_id) - if not loaded: - return None + active_count = self.active_session_count + if active_count >= MAX_SESSIONS: + logger.warning( + "Cannot restore session %s: server at capacity (%d/%d)", + session_id, + active_count, + MAX_SESSIONS, + ) + return None - async with self._lock: - existing = self.sessions.get(session_id) - if existing: - if self._can_access_session(existing, user_id): - self._update_hf_identity( - existing, - hf_token=hf_token, + reserved_session = self._make_reserved_session( + session_id=session_id, + user_id=user_id, hf_username=hf_username, + hf_token=hf_token, + submission_queue=submission_queue, ) - self._restart_cpu_preload_if_token_recovered( - existing, - preload_sandbox=preload_sandbox, - ) - existing.last_access = datetime.utcnow() - return existing - return None + self.sessions[session_id] = reserved_session + should_release_reserved_slot = True - meta = loaded.get("metadata") or {} - owner = str(meta.get("user_id") or "") - if user_id != "dev" and owner != "dev" and owner != user_id: - return None + store = self._store() + loaded = await store.load_session(session_id) + if not loaded: + return None - await self._cleanup_persisted_sandbox( - session_id, - meta, - hf_token=hf_token, - ) + async with self._lock: + existing = self.sessions.get(session_id) + if existing is not reserved_session: + if existing and getattr(existing, "session", None) is not None: + if self._can_access_session(existing, user_id): + self._update_hf_identity( + existing, + hf_token=hf_token, + hf_username=hf_username, + ) + self._restart_cpu_preload_if_token_recovered( + existing, + preload_sandbox=preload_sandbox, + ) + existing.last_access = datetime.utcnow() + should_release_reserved_slot = False + return existing + return None + + meta = loaded.get("metadata") or {} + owner = str(meta.get("user_id") or "") + if user_id != "dev" and owner != "dev" and owner != user_id: + return None - from litellm import Message + await self._cleanup_persisted_sandbox( + session_id, + meta, + hf_token=hf_token, + ) - model = meta.get("model") or self.config.model_name - event_queue: asyncio.Queue = asyncio.Queue() - submission_queue: asyncio.Queue = asyncio.Queue() - tool_router, session = await asyncio.to_thread( - self._create_session_sync, - session_id=session_id, - user_id=owner or user_id, - hf_username=hf_username, - hf_token=hf_token, - model=model, - event_queue=event_queue, - notification_destinations=meta.get("notification_destinations") or [], - ) + from litellm import Message - restored_messages: list[Message] = [] - for raw in loaded.get("messages") or []: - if not isinstance(raw, dict) or raw.get("role") == "system": - continue - try: - restored_messages.append(Message.model_validate(raw)) - except Exception as e: - logger.warning("Dropping malformed restored message: %s", e) - if restored_messages: - # Keep the freshly-rendered system prompt, then attach the durable - # non-system context so tools/date/user context stay current. - session.context_manager.items = [ - session.context_manager.items[0], - *restored_messages, - ] - - self._restore_pending_approval(session, meta.get("pending_approval") or []) - session.turn_count = int(meta.get("turn_count") or 0) - session.auto_approval_enabled = bool(meta.get("auto_approval_enabled", False)) - raw_cap = meta.get("auto_approval_cost_cap_usd") - session.auto_approval_cost_cap_usd = ( - float(raw_cap) if isinstance(raw_cap, int | float) else None - ) - session.auto_approval_estimated_spend_usd = float( - meta.get("auto_approval_estimated_spend_usd") or 0.0 - ) + model = meta.get("model") or self.config.model_name + tool_router, session = await asyncio.to_thread( + self._create_session_sync, + session_id=session_id, + user_id=owner or user_id, + hf_username=hf_username, + hf_token=hf_token, + model=model, + event_queue=event_queue, + notification_destinations=meta.get("notification_destinations") or [], + ) - created_at = meta.get("created_at") - if not isinstance(created_at, datetime): - created_at = datetime.utcnow() + restored_messages: list[Message] = [] + for raw in loaded.get("messages") or []: + if not isinstance(raw, dict) or raw.get("role") == "system": + continue + try: + restored_messages.append(Message.model_validate(raw)) + except Exception as e: + logger.warning("Dropping malformed restored message: %s", e) + if restored_messages: + # Keep the freshly-rendered system prompt, then attach the durable + # non-system context so tools/date/user context stay current. + session.context_manager.items = [ + session.context_manager.items[0], + *restored_messages, + ] + + self._restore_pending_approval(session, meta.get("pending_approval") or []) + session.turn_count = int(meta.get("turn_count") or 0) + session.auto_approval_enabled = bool( + meta.get("auto_approval_enabled", False) + ) + raw_cap = meta.get("auto_approval_cost_cap_usd") + session.auto_approval_cost_cap_usd = ( + float(raw_cap) if isinstance(raw_cap, int | float) else None + ) + session.auto_approval_estimated_spend_usd = float( + meta.get("auto_approval_estimated_spend_usd") or 0.0 + ) - agent_session = AgentSession( - session_id=session_id, - session=session, - tool_router=tool_router, - submission_queue=submission_queue, - user_id=owner or user_id, - hf_username=hf_username, - hf_token=hf_token, - created_at=created_at, - is_active=True, - is_processing=False, - claude_counted=bool(meta.get("claude_counted")), - title=meta.get("title"), - ) - started = await self._start_agent_session( - agent_session=agent_session, - event_queue=event_queue, - tool_router=tool_router, - ) - if started is not agent_session: - self._update_hf_identity( - started, - hf_token=hf_token, + created_at = meta.get("created_at") + if not isinstance(created_at, datetime): + created_at = datetime.utcnow() + + agent_session = AgentSession( + session_id=session_id, + session=session, + tool_router=tool_router, + submission_queue=submission_queue, + user_id=owner or user_id, hf_username=hf_username, + hf_token=hf_token, + created_at=created_at, + is_active=True, + is_processing=False, + claude_counted=bool(meta.get("claude_counted")), + title=meta.get("title"), ) - started.last_access = datetime.utcnow() - return started - if preload_sandbox: - self._start_cpu_sandbox_preload(agent_session) - logger.info("Restored session %s for user %s", session_id, owner or user_id) - return agent_session + started = await self._start_agent_session( + agent_session=agent_session, + event_queue=event_queue, + tool_router=tool_router, + ) + if started is not agent_session: + self._update_hf_identity( + started, + hf_token=hf_token, + hf_username=hf_username, + ) + started.last_access = datetime.utcnow() + should_release_reserved_slot = False + return started + if preload_sandbox: + self._start_cpu_sandbox_preload(agent_session) + logger.info("Restored session %s for user %s", session_id, owner or user_id) + should_release_reserved_slot = False + return agent_session + finally: + if should_release_reserved_slot: + await self._release_reserved_session_slot(session_id, reserved_session) async def create_session( self, diff --git a/tests/unit/test_session_capacity.py b/tests/unit/test_session_capacity.py index 41ca048a..16a81573 100644 --- a/tests/unit/test_session_capacity.py +++ b/tests/unit/test_session_capacity.py @@ -1,36 +1,63 @@ import asyncio +import importlib import sys +from datetime import datetime, timedelta from pathlib import Path from types import SimpleNamespace import pytest import types -# Prevent importing heavy third-party modules when importing the backend module. -for _mod in ("litellm", "fastmcp", "thefuzz", "huggingface_hub"): - if _mod not in sys.modules: - m = types.ModuleType(_mod) - # fastmcp is imported as `from fastmcp import Client` in some codepaths - if _mod == "fastmcp": - class _DummyClient: - pass +_BACKEND_DIR = Path(__file__).resolve().parent.parent.parent / "backend" - setattr(m, "Client", _DummyClient) - sys.modules[_mod] = m -_BACKEND_DIR = Path(__file__).resolve().parent.parent.parent / "backend" -if str(_BACKEND_DIR) not in sys.path: - sys.path.insert(0, str(_BACKEND_DIR)) +@pytest.fixture +def session_manager_module(monkeypatch): + """Import backend.session_manager with temporary dependency stubs. + + The stubs are inserted with monkeypatch so they are restored after the test, + and the imported session_manager module is removed from sys.modules to avoid + leaking the stubbed import state into other tests. + """ + with monkeypatch.context() as m: + m.syspath_prepend(str(_BACKEND_DIR)) + + litellm_stub = types.ModuleType("litellm") + + class _DummyMessage: + @staticmethod + def model_validate(raw): + return SimpleNamespace(**raw) -from session_manager import SessionManager, AgentSession, MAX_SESSIONS + setattr(litellm_stub, "Message", _DummyMessage) + m.setitem(sys.modules, "litellm", litellm_stub) + + fastmcp_stub = types.ModuleType("fastmcp") + + class _DummyClient: + pass + + setattr(fastmcp_stub, "Client", _DummyClient) + m.setitem(sys.modules, "fastmcp", fastmcp_stub) + + m.setitem(sys.modules, "thefuzz", types.ModuleType("thefuzz")) + m.setitem( + sys.modules, + "huggingface_hub", + types.ModuleType("huggingface_hub"), + ) + + module = importlib.import_module("session_manager") + yield module + + sys.modules.pop("session_manager", None) @pytest.mark.asyncio -async def test_restore_denied_when_at_capacity(caplog): - manager = SessionManager() - # Fill in-memory sessions up to MAX_SESSIONS - for i in range(MAX_SESSIONS): - manager.sessions[str(i)] = AgentSession( +async def test_restore_denied_when_at_capacity(session_manager_module, caplog): + manager = session_manager_module.SessionManager() + for i in range(session_manager_module.MAX_SESSIONS): + manager.sessions[str(i)] = session_manager_module.AgentSession( session_id=str(i), session=object(), tool_router=None, @@ -53,8 +80,8 @@ async def load_session(self, sid): @pytest.mark.asyncio -async def test_restore_allowed_under_capacity(monkeypatch): - manager = SessionManager() +async def test_restore_allowed_under_capacity(session_manager_module, monkeypatch): + manager = session_manager_module.SessionManager() manager.sessions.clear() class DummyStore: @@ -65,19 +92,21 @@ async def load_session(self, sid): manager.persistence_store = DummyStore() - # Replace the heavy sync constructor with a lightweight fake - def fake_create_session_sync(*, session_id, user_id, hf_username, hf_token, model, event_queue, notification_destinations): + def fake_create_session_sync( + *, session_id, user_id, hf_username, hf_token, model, event_queue, notification_destinations + ): class SimpleSession: def __init__(self): self.context_manager = SimpleNamespace(items=[object()]) self.pending_approval = None self.turn_count = 0 + self.notification_destinations = [] + self.config = SimpleNamespace(model_name=model) return None, SimpleSession() monkeypatch.setattr(manager, "_create_session_sync", fake_create_session_sync) - # Fake _start_agent_session to register the session without starting tasks async def fake_start_agent_session(*, agent_session, event_queue, tool_router): async with manager._lock: manager.sessions[agent_session.session_id] = agent_session @@ -87,4 +116,130 @@ async def fake_start_agent_session(*, agent_session, event_queue, tool_router): res = await manager.ensure_session_loaded("restored", user_id="test") assert res is not None - assert res.session is not None \ No newline at end of file + assert res.session is not None + + +@pytest.mark.asyncio +async def test_restore_rolls_back_placeholder_on_load_failure(session_manager_module): + manager = session_manager_module.SessionManager() + + class FailingStore: + enabled = True + + async def load_session(self, sid): + raise RuntimeError("load failed") + + manager.persistence_store = FailingStore() + + with pytest.raises(RuntimeError, match="load failed"): + await manager.ensure_session_loaded("restored", user_id="test") + + assert manager.sessions == {} + + +@pytest.mark.asyncio +async def test_create_session_rolls_back_placeholder_on_failure( + session_manager_module, monkeypatch +): + manager = session_manager_module.SessionManager() + + def fake_create_session_sync(**kwargs): + raise RuntimeError("boom") + + monkeypatch.setattr(manager, "_create_session_sync", fake_create_session_sync) + + with pytest.raises(RuntimeError, match="boom"): + await manager.create_session(user_id="u1") + + assert manager.sessions == {} + + +@pytest.mark.asyncio +async def test_unload_inactive_session_persists_and_removes(session_manager_module): + manager = session_manager_module.SessionManager() + persisted = [] + + async def fake_persist_session_snapshot(agent_session, **kwargs): + persisted.append((agent_session.session_id, kwargs)) + + manager.persist_session_snapshot = fake_persist_session_snapshot + + stale_session = session_manager_module.AgentSession( + session_id="stale", + session=SimpleNamespace( + pending_approval=None, + notification_destinations=[], + turn_count=0, + config=SimpleNamespace(model_name="gpt"), + ), + tool_router=None, + submission_queue=asyncio.Queue(), + is_active=True, + last_access=datetime.utcnow() + - session_manager_module.INACTIVE_SESSION_IDLE_THRESHOLD + - timedelta(seconds=1), + ) + manager.sessions["stale"] = stale_session + + await manager._unload_inactive_sessions_once() + + assert "stale" not in manager.sessions + assert persisted == [ + ( + "stale", + {"runtime_state": "idle", "status": "inactive"}, + ) + ] + + +@pytest.mark.asyncio +async def test_unload_inactive_sessions_skips_active_and_processing( + session_manager_module, +): + manager = session_manager_module.SessionManager() + persisted = [] + + async def fake_persist_session_snapshot(agent_session, **kwargs): + persisted.append(agent_session.session_id) + + manager.persist_session_snapshot = fake_persist_session_snapshot + + active_session = session_manager_module.AgentSession( + session_id="active", + session=SimpleNamespace( + pending_approval=None, + notification_destinations=[], + turn_count=0, + config=SimpleNamespace(model_name="gpt"), + ), + tool_router=None, + submission_queue=asyncio.Queue(), + is_active=True, + last_access=datetime.utcnow() + - session_manager_module.INACTIVE_SESSION_IDLE_THRESHOLD + - timedelta(seconds=1), + ) + processing_session = session_manager_module.AgentSession( + session_id="processing", + session=SimpleNamespace( + pending_approval=None, + notification_destinations=[], + turn_count=0, + config=SimpleNamespace(model_name="gpt"), + ), + tool_router=None, + submission_queue=asyncio.Queue(), + is_active=True, + is_processing=True, + last_access=datetime.utcnow() + - session_manager_module.INACTIVE_SESSION_IDLE_THRESHOLD + - timedelta(seconds=1), + ) + manager.sessions["active"] = active_session + manager.sessions["processing"] = processing_session + + await manager._unload_inactive_sessions_once() + + assert "active" in manager.sessions + assert "processing" in manager.sessions + assert persisted == []