Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 143 additions & 16 deletions src/api/routes/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@
import asyncio
import logging
import time
from typing import Any, Dict, List
from collections import defaultdict, deque
from typing import Any, Callable, Dict, List

from fastapi import APIRouter, Depends, Request, UploadFile, File
from fastapi.responses import JSONResponse
from langchain_core.messages import HumanMessage

from src.api.dependencies import (
enforce_rate_limit,
get_code_pipeline,
get_ingest_pipeline,
get_retrieval_pipeline,
require_api_key,
Expand All @@ -41,6 +44,7 @@
WeaverSummary,
)
from src.pipelines.retrieval import RetrievalPipeline
from src.prompts.retrieval import ANSWER_PROMPT

from bs4 import BeautifulSoup
import json
Expand All @@ -63,6 +67,26 @@
dependencies=[Depends(enforce_rate_limit)],
)

_SEARCH_LATENCIES: Dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=500))


def _record_search_latency(mode: str, elapsed_ms: float) -> Dict[str, float]:
samples = _SEARCH_LATENCIES[mode]
samples.append(elapsed_ms)
ordered = sorted(samples)

def pct(percent: float) -> float:
if not ordered:
return 0.0
idx = min(len(ordered) - 1, int(round((len(ordered) - 1) * percent)))
return round(ordered[idx], 2)
Comment on lines +74 to +82
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 _record_search_latency calls len(ordered) repeatedly inside the nested pct closure. Capturing it once as n avoids redundant calls and makes the guard condition symmetric.

Suggested change
samples = _SEARCH_LATENCIES[mode]
samples.append(elapsed_ms)
ordered = sorted(samples)
def pct(percent: float) -> float:
if not ordered:
return 0.0
idx = min(len(ordered) - 1, int(round((len(ordered) - 1) * percent)))
return round(ordered[idx], 2)
samples = _SEARCH_LATENCIES[mode]
samples.append(elapsed_ms)
ordered = sorted(samples)
n = len(ordered)
def pct(percent: float) -> float:
if not n:
return 0.0
idx = min(n - 1, int(round((n - 1) * percent)))
return round(ordered[idx], 2)

Fix in Cursor Fix in Codex Fix in Claude Code


return {
f"{mode}_p50": pct(0.50),
f"{mode}_p95": pct(0.95),
f"{mode}_p99": pct(0.99),
}


# Helpers
def _model_name(model: Any) -> str:
Expand Down Expand Up @@ -679,27 +703,54 @@ async def retrieve_memory(req: RetrieveRequest, request: Request, user: dict = D
@router.post(
"/search",
response_model=APIResponse,
summary="Raw semantic search across memory domains (no LLM answer)",
summary="Fast raw search across memory domains, with optional LLM answer synthesis",
)
async def search_memory(req: SearchRequest, request: Request, user: dict = Depends(require_api_key)):
start = time.perf_counter()
pipeline = get_retrieval_pipeline()

# Get username from authenticated user

user_id = user.get("username") or user.get("name") or user["id"]
mode = "answer" if req.answer else "raw"

try:
timings: Dict[str, float] = {}
all_results: List[SourceRecord] = []
domains = list(dict.fromkeys(req.domains))

if "profile" in domains:
all_results.extend(await _timed_sync("profile", timings, _search_profile, pipeline, user_id))
if "temporal" in domains:
all_results.extend(await _timed_sync("temporal", timings, _search_temporal, pipeline, req.query, user_id, req.top_k))
if "summary" in domains:
all_results.extend(await _timed_async("summary", timings, _search_summary, pipeline, req.query, user_id, req.top_k))
if "snippet" in domains:
all_results.extend(await _timed_async("snippet", timings, _search_snippet, pipeline, req.query, user_id, req.top_k))
if "code" in domains:
if not req.org_id or not req.repo:
return _error(request, "org_id and repo are required when domains includes 'code'.", 400, 0)
all_results.extend(await _timed_async("code", timings, _search_code, req, user_id))
Comment on lines +720 to +731
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The domain searches are currently executed sequentially. To improve performance for this 'fast search' endpoint, these searches should be run concurrently using asyncio.gather.

Suggested change
if "profile" in domains:
all_results.extend(await _timed_sync("profile", timings, _search_profile, pipeline, user_id))
if "temporal" in domains:
all_results.extend(await _timed_sync("temporal", timings, _search_temporal, pipeline, req.query, user_id, req.top_k))
if "summary" in domains:
all_results.extend(await _timed_async("summary", timings, _search_summary, pipeline, req.query, user_id, req.top_k))
if "snippet" in domains:
all_results.extend(await _timed_async("snippet", timings, _search_snippet, pipeline, req.query, user_id, req.top_k))
if "code" in domains:
if not req.org_id or not req.repo:
return _error(request, "org_id and repo are required when domains includes 'code'.", 400, 0)
all_results.extend(await _timed_async("code", timings, _search_code, req, user_id))
tasks = []
if "profile" in domains:
tasks.append(_timed_sync("profile", timings, _search_profile, pipeline, user_id))
if "temporal" in domains:
tasks.append(_timed_sync("temporal", timings, _search_temporal, pipeline, req.query, user_id, req.top_k))
if "summary" in domains:
tasks.append(_timed_async("summary", timings, _search_summary, pipeline, req.query, user_id, req.top_k))
if "snippet" in domains:
tasks.append(_timed_async("snippet", timings, _search_snippet, pipeline, req.query, user_id, req.top_k))
if "code" in domains:
if not req.org_id or not req.repo:
return _error(request, "org_id and repo are required when domains includes 'code'.", 400, 0)
tasks.append(_timed_async("code", timings, _search_code, req, user_id))
results_from_domains = await asyncio.gather(*tasks)
for domain_results in results_from_domains:
all_results.extend(domain_results)

Comment on lines +720 to +731
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Domain searches execute sequentially, not concurrently

Each domain is awaited one-at-a-time in a serial if/await chain. For a request that includes all five domains, the total latency is the sum of all domain round-trips rather than their maximum. The PR description frames this as a "fast search path," but searching profile → temporal → summary → snippet → code in sequence cancels most of the speed benefit. Using asyncio.gather() (with per-domain error isolation) would make the total latency roughly equal to the slowest single domain instead of their sum.

Fix in Cursor Fix in Codex Fix in Claude Code


all_results.sort(key=lambda r: r.score, reverse=True)
answer_text = None
confidence = min(1.0, len(all_results) * 0.2) if all_results else 0.0
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 confidence score is based purely on result count, not relevance

confidence = min(1.0, len(all_results) * 0.2) reaches 1.0 once there are 5 or more results, regardless of their scores. A search that returns 5 results all with score=0.01 will report the same 100% confidence as one with 5 results at score=0.99. Callers relying on this field for quality-gating could be misled. Consider basing confidence on the actual score distribution (e.g., mean or max score of the top results) rather than count alone.

Fix in Cursor Fix in Codex Fix in Claude Code


if req.answer:
answer_start = time.perf_counter()
answer_text = await _synthesize_search_answer(pipeline, req.query, all_results)
timings["answer"] = round((time.perf_counter() - answer_start) * 1000, 2)

if "profile" in req.domains:
all_results.extend(_search_profile(pipeline, user_id))
if "temporal" in req.domains:
all_results.extend(_search_temporal(pipeline, req.query, user_id, req.top_k))
if "summary" in req.domains:
all_results.extend(await _search_summary(pipeline, req.query, user_id, req.top_k))

data = SearchResponse(results=all_results, total=len(all_results))
elapsed = round((time.perf_counter() - start) * 1000, 2)
timings[f"{mode}_total"] = elapsed
timings.update(_record_search_latency(mode, elapsed))

data = SearchResponse(
results=all_results,
total=len(all_results),
mode=mode,
answer=answer_text,
confidence=confidence,
latency_ms=timings,
)
return _wrap(request, data, elapsed)

except Exception as exc:
Expand All @@ -708,12 +759,28 @@ async def search_memory(req: SearchRequest, request: Request, user: dict = Depen
return _error(request, str(exc), 500, elapsed)


async def _timed_sync(label: str, timings: Dict[str, float], fn: Callable, *args) -> List[SourceRecord]:
    """Run a blocking search callable in a worker thread and record its duration.

    Args:
        label: Key under which the elapsed time is stored in ``timings``.
        timings: Mapping updated in place with the elapsed milliseconds.
        fn: Synchronous callable executed through ``asyncio.to_thread``.
        *args: Positional arguments forwarded to ``fn``.

    Returns:
        Whatever ``fn`` returns.
    """
    began = time.perf_counter()
    found = await asyncio.to_thread(fn, *args)
    elapsed_ms = (time.perf_counter() - began) * 1000
    # Only reached when fn succeeds; an exception propagates before any
    # timing entry is written.
    timings[label] = round(elapsed_ms, 2)
    return found


async def _timed_async(label: str, timings: Dict[str, float], fn: Callable, *args) -> List[SourceRecord]:
    """Await an async search callable and record its duration.

    Args:
        label: Key under which the elapsed time is stored in ``timings``.
        timings: Mapping updated in place with the elapsed milliseconds.
        fn: Callable whose result is awaited directly on the event loop.
        *args: Positional arguments forwarded to ``fn``.

    Returns:
        Whatever awaiting ``fn(*args)`` yields.
    """
    began = time.perf_counter()
    found = await fn(*args)
    elapsed_ms = (time.perf_counter() - began) * 1000
    # Only reached when the awaited call succeeds; an exception propagates
    # before any timing entry is written.
    timings[label] = round(elapsed_ms, 2)
    return found


def _search_profile(pipeline: RetrievalPipeline, user_id: str) -> List[SourceRecord]:
try:
raw = pipeline.vector_store.search_by_metadata(
filters={"user_id": user_id, "domain": "profile"}, top_k=100,
)
return [SourceRecord(domain="profile", content=r.content, score=r.score, metadata=r.metadata) for r in raw]
_, raw = pipeline._fetch_profile_catalog(user_id)
pipeline._cached_profile_records = raw
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This line appears to be unnecessary. The _cached_profile_records attribute is not used within this file's logic, and setting it here creates a potentially confusing side effect on the pipeline object. Please consider removing it.

return [
SourceRecord(domain="profile", content=r.content, score=r.score, metadata={"id": r.id, **r.metadata})
for r in raw
]
except Exception as exc:
logger.warning("Profile search error: %s", exc)
return []
Expand Down Expand Up @@ -763,6 +830,66 @@ async def _search_summary(pipeline: RetrievalPipeline, query: str, user_id: str,
return []


async def _search_snippet(pipeline: RetrievalPipeline, query: str, user_id: str, top_k: int) -> List[SourceRecord]:
    """Search the snippet domain, falling back to an empty list on failure.

    Delegates to ``pipeline._search_snippet``. Any exception is logged as a
    warning and swallowed, so the caller receives ``[]`` rather than an error.

    Args:
        pipeline: Retrieval pipeline that performs the snippet search.
        query: Search text.
        user_id: Identifier of the user whose snippets are searched.
        top_k: Result limit forwarded to the pipeline.

    Returns:
        The records produced by the pipeline, or ``[]`` if the search raised.
    """
    try:
        hits = await pipeline._search_snippet(query=query, user_id=user_id, top_k=top_k)
    except Exception as exc:
        logger.warning("Snippet search error: %s", exc)
        return []
    return hits


async def _search_code(req: SearchRequest, user_id: str) -> List[SourceRecord]:
    """Search the code domain via the SearchSymbols and SearchFiles tools.

    The scanner access check runs first; when it reports a denial, the denial
    is logged and no search is performed. Otherwise both tools are awaited
    together with ``asyncio.gather`` and their results are returned as
    ``code``-domain records, symbol hits first.

    Args:
        req: Search request supplying ``query``, ``top_k``, ``org_id``,
            ``repo`` and ``project_id``.
        user_id: Identifier of the requesting user.

    Returns:
        Combined symbol and file results, or ``[]`` when access is denied or
        any step raises (the exception is logged as a warning).
    """
    try:
        # Function-scope import kept from the original; presumably avoids an
        # import cycle with the scanner routes — TODO confirm.
        from src.api.routes.scanner import _get_code_store, _scanner_chat_allowed

        org_id = req.org_id or ""
        repo = req.repo or ""

        denied = _scanner_chat_allowed(_get_code_store(), user_id, org_id, repo)
        if denied:
            logger.warning("Code search denied for user=%s org=%s repo=%s: %s", user_id, req.org_id, req.repo, denied)
            return []

        pipeline = get_code_pipeline(org_id=org_id, repo=repo, project_id=req.project_id)

        def run_tool(tool_name: str):
            # A fresh tool_args dict per call so the two invocations never share state.
            return pipeline._execute_tool(
                tool_name=tool_name,
                tool_args={"query": req.query, "repo": repo},
                repo=repo,
                top_k=req.top_k,
                user_id=user_id,
            )

        symbol_results, file_results = await asyncio.gather(
            run_tool("SearchSymbols"),
            run_tool("SearchFiles"),
        )

        records = []
        for hit in (*symbol_results, *file_results):
            records.append(
                SourceRecord(domain="code", content=hit.content, score=hit.score, metadata=hit.metadata)
            )
        return records
    except Exception as exc:
        logger.warning("Code search error: %s", exc)
        return []


async def _synthesize_search_answer(
    pipeline: RetrievalPipeline, query: str, sources: List[SourceRecord]
) -> str:
    """Ask the pipeline's model for an answer grounded in the search results.

    Each source is rendered as a numbered entry (1-based) showing its domain,
    its score to three decimals, and its content. The entries are joined into
    the ``context`` slot of ``ANSWER_PROMPT``; a placeholder sentence is used
    when there are no sources.

    Args:
        pipeline: Retrieval pipeline whose ``model.ainvoke`` is called.
        query: The user's search text, placed in the prompt's ``query`` slot.
        sources: Search results, used in the order given.

    Returns:
        The model output as a string. When the output content is a list, each
        dict item carrying a ``"text"`` key contributes that value and every
        other item contributes ``str(item)``, joined with newlines.
    """
    entries = [
        f"[{position}] domain={source.domain} score={source.score:.3f}\n{source.content}"
        for position, source in enumerate(sources, start=1)
    ]
    context = "\n".join(entries) or "(No memory search results found.)"
    prompt = ANSWER_PROMPT.format(context=context, query=query)

    response = await pipeline.model.ainvoke([HumanMessage(content=prompt)])
    # Fall back to the response object itself when it has no `content` attribute.
    content = getattr(response, "content", response)

    if not isinstance(content, list):
        return str(content)
    return "\n".join(
        item["text"] if isinstance(item, dict) and "text" in item else str(item)
        for item in content
    )


# POST /v1/memory/scrape
@scrape_router.post(
"/scrape",
Expand Down
32 changes: 29 additions & 3 deletions src/api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel, Field, field_validator

Expand Down Expand Up @@ -159,15 +159,37 @@ class SearchRequest(BaseModel):
..., min_length=1, max_length=256, pattern=r"^[\w.\-@]+$",
)
domains: List[str] = Field(
default=["profile", "temporal", "summary"],
default=["profile", "temporal", "summary", "snippet"],
description="Which memory domains to search",
)
top_k: int = Field(default=10, ge=1, le=100)
answer: bool = Field(
default=False,
description="When true, also run LLM answer synthesis using the ranked search results.",
)
org_id: Optional[str] = Field(
default=None,
min_length=1,
max_length=256,
description="Organization ID for code domain searches.",
)
repo: Optional[str] = Field(
default=None,
min_length=1,
max_length=256,
description="Repository name for code domain searches.",
)
project_id: Optional[str] = Field(
default=None,
min_length=1,
max_length=256,
description="Project ID for annotation-aware code searches.",
)

@field_validator("domains")
@classmethod
def validate_domains(cls, v: List[str]) -> List[str]:
allowed = {"profile", "temporal", "summary"}
allowed = {"profile", "temporal", "summary", "snippet", "code"}
for d in v:
if d not in allowed:
raise ValueError(f"Invalid domain '{d}'. Allowed: {allowed}")
Expand All @@ -177,6 +199,10 @@ def validate_domains(cls, v: List[str]) -> List[str]:
class SearchResponse(BaseModel):
    # Payload returned by the /search endpoint.
    # Matched records from the searched domains.
    results: List[SourceRecord] = Field(default_factory=list)
    # Number of records in `results`.
    total: int = Field(default=0)
    # "answer" when LLM answer synthesis was requested, otherwise "raw".
    mode: Literal["raw", "answer"] = Field(default="raw")
    # Synthesized answer text; None when no synthesis was performed.
    answer: Optional[str] = Field(default=None)
    # Confidence value reported alongside the results.
    confidence: float = Field(default=0.0)
    # Timing measurements in milliseconds, keyed by label.
    latency_ms: Dict[str, float] = Field(default_factory=dict)


# ── Scrape (extract from shared chat links) ────────────────────────────────
Expand Down
24 changes: 21 additions & 3 deletions src/pipelines/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import asyncio
import logging
import time
from typing import Any, Callable, Dict, List, Optional

from dotenv import load_dotenv
Expand Down Expand Up @@ -133,6 +134,10 @@ def __init__(

self.embed_fn = embed_fn
self._snippet_stores: Dict[str, BaseVectorStore] = {}
self._profile_catalog_cache: Dict[str, tuple[float, List[Dict[str, str]], List[Any]]] = {}
self._profile_catalog_ttl_seconds = 60.0
self._retrieval_plan_cache: Dict[tuple[str, str, str, int], tuple[float, AIMessage]] = {}
self._retrieval_plan_ttl_seconds = 30.0

logger.info("RetrievalPipeline initialized")

Expand Down Expand Up @@ -169,8 +174,15 @@ async def run(
HumanMessage(content=query),
]

ai_response: AIMessage = await self.model_with_tools.ainvoke(messages)
logger.info("LLM response received (tool_calls=%d)", len(ai_response.tool_calls or []))
plan_cache_key = (user_id, query, catalog_text, top_k)
cached_plan = self._retrieval_plan_cache.get(plan_cache_key)
if cached_plan and time.monotonic() - cached_plan[0] < self._retrieval_plan_ttl_seconds:
ai_response = cached_plan[1]
logger.info("LLM retrieval plan cache hit (tool_calls=%d)", len(ai_response.tool_calls or []))
else:
ai_response: AIMessage = await self.model_with_tools.ainvoke(messages)
self._retrieval_plan_cache[plan_cache_key] = (time.monotonic(), ai_response)
Comment on lines 139 to +184
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Unbounded retrieval plan cache on the singleton pipeline

_retrieval_plan_cache uses (user_id, query, catalog_text, top_k) as a key, where query is a free-text string. Because queries are rarely repeated verbatim within the 30-second TTL window, the cache almost never hits in practice — but every request still appends a new entry. Stale entries are only replaced when an exact key match occurs on a subsequent read; there is no periodic eviction or max-size cap. On the singleton RetrievalPipeline, this means the dict grows indefinitely over the process lifetime, accumulating AIMessage objects for every unique query ever seen. A production system under load will leak memory until the process OOM-restarts. Adding a bounded structure (e.g., a maxsize-capped LRUCache) or an explicit eviction sweep is needed.

Fix in Cursor Fix in Codex Fix in Claude Code

logger.info("LLM response received (tool_calls=%d)", len(ai_response.tool_calls or []))

# ── Step 2: Execute tool calls ────────────────────────────────
sources: List[SourceRecord] = []
Expand Down Expand Up @@ -487,13 +499,18 @@ async def _search_snippet(
# ------------------------------------------------------------------

def _fetch_profile_catalog(self, user_id: str):
"""Fetch all profile entries for a user.
"""Fetch all profile entries for a user, caching the catalog briefly.

Returns:
(catalog, raw_results)
catalog — list of {topic, sub_topic} for the prompt
raw_results — the full SearchResult list, cached for _search_profile
"""
now = time.monotonic()
cached = self._profile_catalog_cache.get(user_id)
if cached and now - cached[0] < self._profile_catalog_ttl_seconds:
return cached[1], cached[2]

try:
results = self.vector_store.search_by_metadata(
filters={"user_id": user_id, "domain": "profile"},
Expand Down Expand Up @@ -524,6 +541,7 @@ def _fetch_profile_catalog(self, user_id: str):
"sub_topic": "",
})

self._profile_catalog_cache[user_id] = (now, catalog, results)
return catalog, results

def _format_catalog(self, catalog: List[Dict[str, str]]) -> str:
Expand Down
Loading
Loading