Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions debug_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import asyncio
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock, patch

from src.api.app import create_app

app = create_app()
client = TestClient(app)

with patch("src.api.routes.memory.require_api_key", return_value={"username": "test_user"}):
from src.api.dependencies import require_api_key, enforce_rate_limit, require_ready
app.dependency_overrides[require_api_key] = lambda: {"username": "test_user"}
app.dependency_overrides[enforce_rate_limit] = lambda: True
app.dependency_overrides[require_ready] = lambda: True

payload = {
"items": [
{
"user_query": "Hello world",
"agent_response": "Hi there",
"user_id": "test_user_1",
}
]
}

try:
response = client.post(
"/v1/memory/batch-ingest",
json=payload,
headers={"Authorization": "Bearer test-key"}
)
print("Status code:", response.status_code)
import json
print(json.dumps(response.json(), indent=2))
except Exception as e:
print("Exception:", e)
Comment on lines +1 to +36
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Unrelated/debug files committed to the repository

debug_test.py, test_output.txt (a binary artifact), and xlsx.py (an Excel workbook generator with no relation to this feature) were accidentally included in the PR. All three should be removed before merging — debug_test.py is a throwaway debug script, test_output.txt is a build artifact, and xlsx.py appears to be a personal utility script that does not belong in this codebase.

Fix in Cursor Fix in Codex Fix in Claude Code

30 changes: 17 additions & 13 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

from src.pipelines.ingest import IngestPipeline
from src.pipelines.retrieval import RetrievalPipeline
from src.api.ingestion_coordinator import UserIngestionCoordinator


# ═══════════════════════════════════════════════════════════════════
Expand Down Expand Up @@ -82,6 +83,7 @@ def emit(self, record: logging.LogRecord) -> None:
_pipelines_ready = asyncio.Event()
_init_error: str | None = None
SKIP_PIPELINES = os.getenv("XMEM_SKIP_PIPELINES", "").lower() in {"1", "true", "yes"}
_user_coordinator = UserIngestionCoordinator()


def _init_pipelines_sync() -> None:
Expand Down Expand Up @@ -315,14 +317,15 @@ async def v1_ingest_memory(req: IngestRequest):
)

try:
result = await ingest_pipeline.run(
user_query=req.user_query,
agent_response=req.agent_response or "Acknowledged.",
user_id=req.user_id,
session_datetime=req.session_datetime,
image_url=req.image_url,
effort_level=req.effort_level,
)
async with _user_coordinator.acquire(req.user_id):
result = await ingest_pipeline.run(
user_query=req.user_query,
agent_response=req.agent_response or "Acknowledged.",
user_id=req.user_id,
session_datetime=req.session_datetime,
image_url=req.image_url,
effort_level=req.effort_level,
)

data = {
"model": _get_model_name(ingest_pipeline.model),
Expand Down Expand Up @@ -368,11 +371,12 @@ async def api_ingest(req: IngestRequest):
lg.addHandler(capture)

try:
result = await ingest_pipeline.run(
user_query=req.user_query,
agent_response=req.agent_response or "Acknowledged.",
user_id=req.user_id,
)
async with _user_coordinator.acquire(req.user_id):
result = await ingest_pipeline.run(
user_query=req.user_query,
agent_response=req.agent_response or "Acknowledged.",
user_id=req.user_id,
)

# Build structured response
response: Dict[str, Any] = {
Expand Down
94 changes: 94 additions & 0 deletions src/api/ingestion_coordinator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""
Per-user ingestion coordinator — serialises ingestion for each user.

Guarantees that only one ingestion pipeline runs at a time for any given
``user_id``, while allowing different users to proceed in parallel.
Requests for the same user are processed in strict FIFO order.

This is the **in-memory** implementation (Option 1). A future distributed
lock (Redis, etc.) can be swapped in by implementing the same ``acquire()``
context-manager interface.

Usage::

from src.api.ingestion_coordinator import UserIngestionCoordinator

coordinator = UserIngestionCoordinator()

async with coordinator.acquire(user_id):
result = await pipeline.run(...)
"""

from __future__ import annotations

import asyncio
import logging
from contextlib import asynccontextmanager
from typing import AsyncIterator, Dict

logger = logging.getLogger("xmem.api.ingestion_coordinator")


class UserIngestionCoordinator:
"""Per-user FIFO ingestion lock.

Internally maintains a ``dict[str, asyncio.Lock]`` keyed by ``user_id``.
Locks are created lazily on first access and removed once no tasks are
waiting or holding them, preventing unbounded memory growth.

Thread-safety note
------------------
All mutations to the internal registry are protected by a single
``asyncio.Lock`` (the *registry lock*). Since this code runs on the
asyncio event loop, ``asyncio.Lock`` is sufficient — no OS-level
threading primitives are needed.
"""

def __init__(self) -> None:
# Maps user_id -> (asyncio.Lock, active_count)
# active_count tracks how many tasks are either holding or waiting
# for the lock so we know when it's safe to clean up.
self._locks: Dict[str, asyncio.Lock] = {}
self._waiters: Dict[str, int] = {}
self._registry_lock = asyncio.Lock()

@asynccontextmanager
async def acquire(self, user_id: str) -> AsyncIterator[None]:
"""Acquire the per-user ingestion lock.

Usage::

async with coordinator.acquire("user_123"):
# Only one coroutine per user_id reaches here at a time.
await do_work()

The lock is automatically released (and cleaned up if idle) when
the ``async with`` block exits, even if an exception is raised.
"""
# ── Get-or-create the user lock ──────────────────────────────
async with self._registry_lock:
if user_id not in self._locks:
self._locks[user_id] = asyncio.Lock()
self._waiters[user_id] = 0
self._waiters[user_id] += 1
user_lock = self._locks[user_id]

logger.debug("User %s: waiting for ingestion lock (waiters=%d)", user_id, self._waiters.get(user_id, 0))

try:
async with user_lock:
logger.debug("User %s: ingestion lock acquired", user_id)
yield
finally:
# ── Cleanup: remove the lock if nobody else is waiting ────
async with self._registry_lock:
self._waiters[user_id] -= 1
if self._waiters[user_id] <= 0:
self._locks.pop(user_id, None)
self._waiters.pop(user_id, None)
logger.debug("User %s: ingestion lock cleaned up", user_id)

@property
def active_users(self) -> int:
"""Return the number of users with active or pending ingestion locks."""
return len(self._locks)
24 changes: 14 additions & 10 deletions src/api/routes/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
WeaverSummary,
)
from src.pipelines.retrieval import RetrievalPipeline
from src.api.ingestion_coordinator import UserIngestionCoordinator

from bs4 import BeautifulSoup
import json
Expand All @@ -58,6 +59,7 @@
logger = logging.getLogger("xmem.api.routes.memory")

_ingest_semaphore = asyncio.Semaphore(5)
_user_coordinator = UserIngestionCoordinator()

router = APIRouter(
prefix="/v1/memory",
Expand Down Expand Up @@ -681,10 +683,11 @@ async def ingest_memory(req: IngestRequest, request: Request, user: dict = Depen
payload = req.model_dump()

try:
data = await asyncio.wait_for(
_run_ingest_payload(payload, user_id),
timeout=120.0,
)
async with _user_coordinator.acquire(user_id):
data = await asyncio.wait_for(
_run_ingest_payload(payload, user_id),
timeout=120.0,
)
elapsed = round((time.perf_counter() - start) * 1000, 2)
return _wrap(request, data, elapsed)

Expand Down Expand Up @@ -801,12 +804,13 @@ async def batch_ingest_memory(req: BatchIngestRequest, request: Request, user: d

try:
results = []
for item in req.items:
data = await asyncio.wait_for(
_run_ingest_payload(item.model_dump(), user_id),
timeout=120.0,
)
results.append(IngestResponse(**data))
async with _user_coordinator.acquire(user_id):
for item in req.items:
data = await asyncio.wait_for(
_run_ingest_payload(item.model_dump(), user_id),
timeout=120.0,
)
results.append(IngestResponse(**data))

response_data = BatchIngestResponse(results=results)
elapsed = round((time.perf_counter() - start) * 1000, 2)
Expand Down
Binary file added test_output.txt
Binary file not shown.
135 changes: 135 additions & 0 deletions tests/test_batch_ingest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock, patch
from typing import Dict, Any

from src.api.app import create_app
from src.api.schemas import BatchIngestRequest, IngestRequest
from src.pipelines.ingest import IngestPipeline

@pytest.fixture
def client():
app = create_app()
return TestClient(app)

@pytest.fixture
def mock_ingest_pipeline():
with patch("src.api.routes.memory.get_ingest_pipeline") as mock_get_pipeline:
from types import SimpleNamespace
mock_pipeline = AsyncMock(spec=IngestPipeline)
mock_pipeline.model = SimpleNamespace(model_name="test-model")

# Default mock behavior
async def mock_run(*args, **kwargs):
return {
"classification_result": SimpleNamespace(classifications=["test"]),
"profile_judge": None,
"profile_weaver": None,
"temporal_judge": None,
"temporal_weaver": None,
"summary_judge": None,
"summary_weaver": None,
"image_judge": None,
"image_weaver": None,
}

mock_pipeline.run.side_effect = mock_run
mock_get_pipeline.return_value = mock_pipeline
yield mock_pipeline

def test_batch_ingest_success(client, mock_ingest_pipeline):
"""Test that multiple items can be successfully ingested in a batch."""
payload = {
"items": [
{
"user_query": "Hello world",
"agent_response": "Hi there",
"user_id": "test_user_1",
},
{
"user_query": "Second message",
"agent_response": "Understood",
"user_id": "test_user_1",
}
]
}

# You must provide API key or mock dependency for require_api_key
# For test purposes, we assume we override the dependency or add a test key
# Let's mock require_api_key in dependencies
with patch("src.api.routes.memory.require_api_key", return_value={"username": "test_user"}):
app = client.app
from src.api.dependencies import require_api_key, enforce_rate_limit, require_ready
app.dependency_overrides[require_api_key] = lambda: {"username": "test_user"}
app.dependency_overrides[enforce_rate_limit] = lambda: True
app.dependency_overrides[require_ready] = lambda: True

response = client.post(
"/v1/memory/batch-ingest",
json=payload,
headers={"Authorization": "Bearer test-key"}
)

assert response.status_code == 200, response.json()
data = response.json()
assert data["status"] == "ok", data
assert len(data["data"]["results"]) == 2, data
for item in data["data"]["results"]:
assert item["model"] == "test-model"


def test_coordinator_serializes_concurrent_batches(client, mock_ingest_pipeline):
"""Two concurrent batch-ingest requests for the same user must not overlap.

We verify this by checking that all 4 pipeline.run calls were made
(2 items × 2 batches) and both requests succeed.
"""
import threading

payload = {
"items": [
{
"user_query": "Batch message 1",
"agent_response": "Ack 1",
"user_id": "same_user",
},
{
"user_query": "Batch message 2",
"agent_response": "Ack 2",
"user_id": "same_user",
},
]
}

with patch("src.api.routes.memory.require_api_key", return_value={"username": "same_user"}):
app = client.app
from src.api.dependencies import require_api_key, enforce_rate_limit, require_ready
app.dependency_overrides[require_api_key] = lambda: {"username": "same_user"}
app.dependency_overrides[enforce_rate_limit] = lambda: True
app.dependency_overrides[require_ready] = lambda: True

# Send two batch requests concurrently via threads
results = [None, None]

def _send_batch(idx):
results[idx] = client.post(
"/v1/memory/batch-ingest",
json=payload,
headers={"Authorization": "Bearer test-key"},
)

t1 = threading.Thread(target=_send_batch, args=(0,))
t2 = threading.Thread(target=_send_batch, args=(1,))
t1.start()
t2.start()
t1.join()
t2.join()

# Both requests should succeed
for r in results:
assert r is not None
assert r.status_code == 200, r.json()

# All 4 pipeline.run calls (2 items × 2 batches) should have been made
assert mock_ingest_pipeline.run.call_count == 4

Loading
Loading