Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 118 additions & 10 deletions backend/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Optional

Expand Down Expand Up @@ -100,6 +100,8 @@ class AgentSession:
is_processing: bool = False # True while a submission is being executed
broadcaster: Any = None
title: str | None = None
# Last time this session was accessed (for idle-unload logic)
last_access: datetime = field(default_factory=datetime.utcnow)
# True once this session has been counted against the user's daily
# Claude quota. Guards double-counting when the user re-selects an
# Anthropic model mid-session.
Expand Down Expand Up @@ -141,10 +143,19 @@ async def start(self) -> None:
self.persistence_store = get_session_store()
await self.persistence_store.init()
await self.messaging_gateway.start()
# Start background cleanup task to unload long-idle sessions.
self._cleanup_task = asyncio.create_task(self._unload_inactive_sessions_loop())

async def close(self) -> None:
"""Flush and close shared background resources."""
await self._cleanup_all_sandboxes_on_close()
# Cancel cleanup task
if getattr(self, "_cleanup_task", None):
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
Comment thread
OmShrivastava19 marked this conversation as resolved.
await self.messaging_gateway.close()
if self.persistence_store is not None:
await self.persistence_store.close()
Expand Down Expand Up @@ -327,8 +338,17 @@ async def _start_agent_session(
async with self._lock:
existing = self.sessions.get(agent_session.session_id)
if existing:
return existing
self.sessions[agent_session.session_id] = agent_session
# If an earlier coroutine reserved the slot with a placeholder
# AgentSession (session is None), replace it with the real
# agent_session. Otherwise return the existing live session.
if getattr(existing, "session", None) is None:
self.sessions[agent_session.session_id] = agent_session
else:
# Update access time for the existing live session
existing.last_access = datetime.utcnow()
return existing
else:
self.sessions[agent_session.session_id] = agent_session

task = asyncio.create_task(
self._run_session(
Expand All @@ -339,6 +359,8 @@ async def _start_agent_session(
)
)
agent_session.task = task
# mark last access time when the session has been started
agent_session.last_access = datetime.utcnow()
return agent_session

@staticmethod
Expand Down Expand Up @@ -494,6 +516,61 @@ async def _cleanup_persisted_sandbox(
last_err,
)

async def _unload_inactive_sessions_loop(
    self,
    idle_timeout: float = 24 * 60 * 60,
    poll_interval: float = 600,
) -> None:
    """Background task: unload sessions that have been idle for too long.

    Every ``poll_interval`` seconds the registry is scanned while holding
    ``self._lock``. A session is selected for unloading when it is active,
    is not currently processing a submission, and its ``last_access``
    (falling back to ``created_at``) is more than ``idle_timeout`` seconds
    old. Selected sessions are marked inactive and removed from
    ``self.sessions`` under the lock; afterwards, outside the lock, a
    snapshot with status ``"inactive"`` is persisted and the session's task
    is cancelled.

    Args:
        idle_timeout: Idle time in seconds after which a session is
            unloaded. Defaults to 24 hours.
        poll_interval: Seconds to sleep between scans. Defaults to
            10 minutes.

    The coroutine runs until it is cancelled, at which point it returns
    quietly.
    """
    max_idle = timedelta(seconds=idle_timeout)
    try:
        while True:
            await asyncio.sleep(poll_interval)
            # Naive UTC on purpose: it is subtracted from last_access /
            # created_at, which are presumably naive utcnow() values too
            # (NOTE(review): confirm before switching to aware datetimes).
            now = datetime.utcnow()
            to_unload: list[tuple[str, AgentSession]] = []
            async with self._lock:
                # Iterate over a copy because entries are deleted in-loop.
                for sid, agent_session in list(self.sessions.items()):
                    try:
                        if not agent_session.is_active:
                            continue
                        if getattr(agent_session, "is_processing", False):
                            continue
                        last = getattr(
                            agent_session,
                            "last_access",
                            agent_session.created_at,
                        )
                        if now - last > max_idle:
                            # Mark inactive and remove from the registry
                            # under the lock; slow work happens later.
                            agent_session.is_active = False
                            to_unload.append((sid, agent_session))
                            del self.sessions[sid]
                    except Exception as e:
                        logger.debug("Skipping unload check for %s: %s", sid, e)

            for sid, agent_session in to_unload:
                try:
                    await self.persist_session_snapshot(
                        agent_session,
                        runtime_state=self._runtime_state(agent_session),
                        status="inactive",
                    )
                except Exception as e:
                    logger.warning(
                        "Failed to persist snapshot before unloading %s: %s",
                        sid,
                        e,
                    )

                # Cancel the running task if present. asyncio.wait() is
                # used instead of awaiting the task directly: awaiting a
                # cancelled task raises CancelledError (a BaseException),
                # which would escape `except Exception` and be mistaken
                # for cancellation of this loop, silently stopping it.
                task = getattr(agent_session, "task", None)
                if task is not None:
                    try:
                        task.cancel()
                        done, _ = await asyncio.wait({task}, timeout=5)
                        if not done:
                            logger.warning(
                                "Task for session %s did not stop within 5s "
                                "of cancellation",
                                sid,
                            )
                        elif not task.cancelled() and task.exception() is not None:
                            # Retrieve the exception so asyncio does not
                            # report it as never retrieved.
                            logger.debug(
                                "Task for session %s ended with error: %s",
                                sid,
                                task.exception(),
                            )
                    except Exception as e:
                        # Best effort: a failure here must not stop the loop.
                        logger.debug(
                            "Error while cancelling task for %s: %s", sid, e
                        )

                logger.info("Unloaded inactive session %s due to inactivity", sid)
    except asyncio.CancelledError:
        # Only reached when this loop's own task is cancelled (e.g. on close).
        return

async def persist_session_snapshot(
self,
agent_session: AgentSession,
Expand Down Expand Up @@ -567,9 +644,22 @@ async def ensure_session_loaded(
existing,
preload_sandbox=preload_sandbox,
)
existing.last_access = datetime.utcnow()
return existing
return None

# Check capacity before restoring from persistence
async with self._lock:
active_count = self.active_session_count
if active_count >= MAX_SESSIONS:
logger.warning(
"Cannot restore session %s: server at capacity (%d/%d)",
session_id,
active_count,
MAX_SESSIONS,
)
return None

store = self._store()
loaded = await store.load_session(session_id)
Comment thread
OmShrivastava19 marked this conversation as resolved.
Outdated
if not loaded:
Expand All @@ -588,6 +678,7 @@ async def ensure_session_loaded(
existing,
preload_sandbox=preload_sandbox,
)
existing.last_access = datetime.utcnow()
return existing
return None

Expand Down Expand Up @@ -674,6 +765,7 @@ async def ensure_session_loaded(
hf_token=hf_token,
hf_username=hf_username,
)
started.last_access = datetime.utcnow()
return started
if preload_sandbox:
self._start_cpu_sandbox_preload(agent_session)
Expand Down Expand Up @@ -706,7 +798,11 @@ async def create_session(
SessionCapacityError: If the server or user has reached the
maximum number of concurrent sessions.
"""
# ── Capacity checks ──────────────────────────────────────────
# ── Capacity checks & reservation ───────────────────────────
# Create lightweight queues up-front (non-blocking).
submission_queue: asyncio.Queue = asyncio.Queue()
event_queue: asyncio.Queue = asyncio.Queue()

async with self._lock:
active_count = self.active_session_count
if active_count >= MAX_SESSIONS:
Expand All @@ -724,11 +820,22 @@ async def create_session(
error_type="per_user",
)

session_id = str(uuid.uuid4())

# Create queues for this session
submission_queue: asyncio.Queue = asyncio.Queue()
event_queue: asyncio.Queue = asyncio.Queue()
session_id = str(uuid.uuid4())
# Reserve the slot with a placeholder AgentSession so concurrent
# creators cannot exceed MAX_SESSIONS. The placeholder has no
# real Session/tool_router yet (session is None) but counts
# towards `active_session_count` because `is_active` is True.
placeholder = AgentSession(
session_id=session_id,
session=None,
tool_router=None,
submission_queue=submission_queue,
user_id=user_id,
hf_username=hf_username,
hf_token=hf_token,
is_active=True,
)
self.sessions[session_id] = placeholder
Comment thread
OmShrivastava19 marked this conversation as resolved.

# Run blocking constructors in a thread to keep the event loop responsive.
tool_router, session = await asyncio.to_thread(
Expand All @@ -741,7 +848,8 @@ async def create_session(
event_queue=event_queue,
)

# Create wrapper
# Create wrapper with the real session resources and replace the
# placeholder in _start_agent_session.
agent_session = AgentSession(
session_id=session_id,
session=session,
Expand Down
90 changes: 90 additions & 0 deletions tests/unit/test_session_capacity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import asyncio
import sys
from pathlib import Path
from types import SimpleNamespace

import pytest
import types

# Prevent importing heavy third-party modules when importing the backend module.
# Register empty stand-in modules so importing the backend module does not
# pull in these heavy third-party packages.
_STUBBED_MODULES = ("litellm", "fastmcp", "thefuzz", "huggingface_hub")

for _mod in _STUBBED_MODULES:
    if _mod in sys.modules:
        # Something already provided this module; leave it untouched.
        continue

    _stub = types.ModuleType(_mod)
    if _mod == "fastmcp":
        # Some codepaths do `from fastmcp import Client`, so the stub must
        # expose that attribute.
        class _DummyClient:
            pass

        _stub.Client = _DummyClient

    sys.modules[_mod] = _stub
Comment thread
OmShrivastava19 marked this conversation as resolved.
Outdated
Comment thread
OmShrivastava19 marked this conversation as resolved.
Outdated

_BACKEND_DIR = Path(__file__).resolve().parent.parent.parent / "backend"
if str(_BACKEND_DIR) not in sys.path:
sys.path.insert(0, str(_BACKEND_DIR))

from session_manager import SessionManager, AgentSession, MAX_SESSIONS


@pytest.mark.asyncio
async def test_restore_denied_when_at_capacity(caplog):
    """Restoring a persisted session returns None and warns when full."""

    class DummyStore:
        # Minimal store that always returns a loadable payload.
        enabled = True

        async def load_session(self, sid):
            return {"metadata": {"user_id": "test", "model": "gpt"}, "messages": []}

    manager = SessionManager()

    # Occupy every slot with an active in-memory session.
    for slot in map(str, range(MAX_SESSIONS)):
        manager.sessions[slot] = AgentSession(
            session_id=slot,
            session=object(),
            tool_router=None,
            submission_queue=asyncio.Queue(),
            is_active=True,
        )

    manager.persistence_store = DummyStore()

    caplog.set_level("WARNING")
    restored = await manager.ensure_session_loaded("restored", user_id="test")

    assert restored is None
    logged = [record.message for record in caplog.records]
    assert any("Cannot restore session" in text for text in logged)


@pytest.mark.asyncio
async def test_restore_allowed_under_capacity(monkeypatch):
    """Restoring a persisted session succeeds when no sessions are loaded."""

    class DummyStore:
        # Minimal store that always returns a loadable payload.
        enabled = True

        async def load_session(self, sid):
            return {"metadata": {"user_id": "test", "model": "gpt"}, "messages": []}

    class SimpleSession:
        # Bare-bones session carrying only the attributes set below.
        def __init__(self):
            self.context_manager = SimpleNamespace(items=[object()])
            self.pending_approval = None
            self.turn_count = 0

    def fake_create_session_sync(
        *,
        session_id,
        user_id,
        hf_username,
        hf_token,
        model,
        event_queue,
        notification_destinations,
    ):
        # Lightweight replacement for the heavy sync constructor.
        return None, SimpleSession()

    manager = SessionManager()
    manager.sessions.clear()
    manager.persistence_store = DummyStore()

    async def fake_start_agent_session(*, agent_session, event_queue, tool_router):
        # Register the session without starting any background task.
        async with manager._lock:
            manager.sessions[agent_session.session_id] = agent_session
        return agent_session

    monkeypatch.setattr(manager, "_create_session_sync", fake_create_session_sync)
    monkeypatch.setattr(manager, "_start_agent_session", fake_start_agent_session)

    restored = await manager.ensure_session_loaded("restored", user_id="test")

    assert restored is not None
    assert restored.session is not None
Loading