-
Notifications
You must be signed in to change notification settings - Fork 38
Add fast memory search mode #188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -9,13 +9,16 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import time | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any, Dict, List | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from collections import defaultdict, deque | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any, Callable, Dict, List | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from fastapi import APIRouter, Depends, Request, UploadFile, File | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from fastapi.responses import JSONResponse | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from langchain_core.messages import HumanMessage | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from src.api.dependencies import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| enforce_rate_limit, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_code_pipeline, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_ingest_pipeline, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_retrieval_pipeline, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| require_api_key, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -41,6 +44,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| WeaverSummary, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from src.pipelines.retrieval import RetrievalPipeline | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from src.prompts.retrieval import ANSWER_PROMPT | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from bs4 import BeautifulSoup | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import json | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -63,6 +67,26 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dependencies=[Depends(enforce_rate_limit)], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _SEARCH_LATENCIES: Dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=500)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _record_search_latency(mode: str, elapsed_ms: float) -> Dict[str, float]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| samples = _SEARCH_LATENCIES[mode] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| samples.append(elapsed_ms) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ordered = sorted(samples) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def pct(percent: float) -> float: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not ordered: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return 0.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| idx = min(len(ordered) - 1, int(round((len(ordered) - 1) * percent))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return round(ordered[idx], 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"{mode}_p50": pct(0.50), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"{mode}_p95": pct(0.95), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"{mode}_p99": pct(0.99), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Helpers | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _model_name(model: Any) -> str: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -679,27 +703,54 @@ async def retrieve_memory(req: RetrieveRequest, request: Request, user: dict = D | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @router.post( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "/search", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| response_model=APIResponse, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| summary="Raw semantic search across memory domains (no LLM answer)", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| summary="Fast raw search across memory domains, with optional LLM answer synthesis", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def search_memory(req: SearchRequest, request: Request, user: dict = Depends(require_api_key)): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| start = time.perf_counter() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pipeline = get_retrieval_pipeline() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Get username from authenticated user | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| user_id = user.get("username") or user.get("name") or user["id"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mode = "answer" if req.answer else "raw" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| timings: Dict[str, float] = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results: List[SourceRecord] = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| domains = list(dict.fromkeys(req.domains)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "profile" in domains: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.extend(await _timed_sync("profile", timings, _search_profile, pipeline, user_id)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "temporal" in domains: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.extend(await _timed_sync("temporal", timings, _search_temporal, pipeline, req.query, user_id, req.top_k)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "summary" in domains: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.extend(await _timed_async("summary", timings, _search_summary, pipeline, req.query, user_id, req.top_k)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "snippet" in domains: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.extend(await _timed_async("snippet", timings, _search_snippet, pipeline, req.query, user_id, req.top_k)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "code" in domains: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not req.org_id or not req.repo: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return _error(request, "org_id and repo are required when domains includes 'code'.", 400, 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.extend(await _timed_async("code", timings, _search_code, req, user_id)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+720
to
+731
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The domain searches are currently executed sequentially. To improve performance for this 'fast search' endpoint, these searches should be run concurrently using
Suggested change
Comment on lines
+720
to
+731
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Each domain is awaited one-at-a-time in a serial |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.sort(key=lambda r: r.score, reverse=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| answer_text = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| confidence = min(1.0, len(all_results) * 0.2) if all_results else 0.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if req.answer: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| answer_start = time.perf_counter() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| answer_text = await _synthesize_search_answer(pipeline, req.query, all_results) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| timings["answer"] = round((time.perf_counter() - answer_start) * 1000, 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "profile" in req.domains: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.extend(_search_profile(pipeline, user_id)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "temporal" in req.domains: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.extend(_search_temporal(pipeline, req.query, user_id, req.top_k)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "summary" in req.domains: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.extend(await _search_summary(pipeline, req.query, user_id, req.top_k)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| data = SearchResponse(results=all_results, total=len(all_results)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elapsed = round((time.perf_counter() - start) * 1000, 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| timings[f"{mode}_total"] = elapsed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| timings.update(_record_search_latency(mode, elapsed)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| data = SearchResponse( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| results=all_results, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| total=len(all_results), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mode=mode, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| answer=answer_text, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| confidence=confidence, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| latency_ms=timings, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return _wrap(request, data, elapsed) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as exc: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -708,12 +759,28 @@ async def search_memory(req: SearchRequest, request: Request, user: dict = Depen | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return _error(request, str(exc), 500, elapsed) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def _timed_sync(label: str, timings: Dict[str, float], fn: Callable, *args) -> List[SourceRecord]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| start = time.perf_counter() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| records = await asyncio.to_thread(fn, *args) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| timings[label] = round((time.perf_counter() - start) * 1000, 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return records | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def _timed_async(label: str, timings: Dict[str, float], fn: Callable, *args) -> List[SourceRecord]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| start = time.perf_counter() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| records = await fn(*args) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| timings[label] = round((time.perf_counter() - start) * 1000, 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return records | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _search_profile(pipeline: RetrievalPipeline, user_id: str) -> List[SourceRecord]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raw = pipeline.vector_store.search_by_metadata( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| filters={"user_id": user_id, "domain": "profile"}, top_k=100, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [SourceRecord(domain="profile", content=r.content, score=r.score, metadata=r.metadata) for r in raw] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _, raw = pipeline._fetch_profile_catalog(user_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pipeline._cached_profile_records = raw | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| SourceRecord(domain="profile", content=r.content, score=r.score, metadata={"id": r.id, **r.metadata}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for r in raw | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as exc: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning("Profile search error: %s", exc) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -763,6 +830,66 @@ async def _search_summary(pipeline: RetrievalPipeline, query: str, user_id: str, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def _search_snippet(pipeline: RetrievalPipeline, query: str, user_id: str, top_k: int) -> List[SourceRecord]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return await pipeline._search_snippet(query=query, user_id=user_id, top_k=top_k) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as exc: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning("Snippet search error: %s", exc) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def _search_code(req: SearchRequest, user_id: str) -> List[SourceRecord]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from src.api.routes.scanner import _get_code_store, _scanner_chat_allowed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| denied = _scanner_chat_allowed(_get_code_store(), user_id, req.org_id or "", req.repo or "") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if denied: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning("Code search denied for user=%s org=%s repo=%s: %s", user_id, req.org_id, req.repo, denied) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pipeline = get_code_pipeline(org_id=req.org_id or "", repo=req.repo or "", project_id=req.project_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| symbol_results, file_results = await asyncio.gather( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pipeline._execute_tool( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tool_name="SearchSymbols", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tool_args={"query": req.query, "repo": req.repo or ""}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| repo=req.repo or "", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| top_k=req.top_k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| user_id=user_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pipeline._execute_tool( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tool_name="SearchFiles", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tool_args={"query": req.query, "repo": req.repo or ""}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| repo=req.repo or "", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| top_k=req.top_k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| user_id=user_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [SourceRecord(domain="code", content=r.content, score=r.score, metadata=r.metadata) for r in [*symbol_results, *file_results]] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as exc: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning("Code search error: %s", exc) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def _synthesize_search_answer( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pipeline: RetrievalPipeline, query: str, sources: List[SourceRecord] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> str: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context = "\n".join( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"[{idx + 1}] domain={source.domain} score={source.score:.3f}\n{source.content}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for idx, source in enumerate(sources) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) or "(No memory search results found.)" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| response = await pipeline.model.ainvoke([HumanMessage(content=ANSWER_PROMPT.format(context=context, query=query))]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| content = response.content if hasattr(response, "content") else response | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(content, list): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| parts = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for item in content: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(item, dict) and "text" in item: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| parts.append(item["text"]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| parts.append(str(item)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return "\n".join(parts) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return str(content) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # POST /v1/memory/scrape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @scrape_router.post( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "/scrape", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ | |
|
|
||
| import asyncio | ||
| import logging | ||
| import time | ||
| from typing import Any, Callable, Dict, List, Optional | ||
|
|
||
| from dotenv import load_dotenv | ||
|
|
@@ -133,6 +134,10 @@ def __init__( | |
|
|
||
| self.embed_fn = embed_fn | ||
| self._snippet_stores: Dict[str, BaseVectorStore] = {} | ||
| self._profile_catalog_cache: Dict[str, tuple[float, List[Dict[str, str]], List[Any]]] = {} | ||
| self._profile_catalog_ttl_seconds = 60.0 | ||
| self._retrieval_plan_cache: Dict[tuple[str, str, str, int], tuple[float, AIMessage]] = {} | ||
| self._retrieval_plan_ttl_seconds = 30.0 | ||
|
|
||
| logger.info("RetrievalPipeline initialized") | ||
|
|
||
|
|
@@ -169,8 +174,15 @@ async def run( | |
| HumanMessage(content=query), | ||
| ] | ||
|
|
||
| ai_response: AIMessage = await self.model_with_tools.ainvoke(messages) | ||
| logger.info("LLM response received (tool_calls=%d)", len(ai_response.tool_calls or [])) | ||
| plan_cache_key = (user_id, query, catalog_text, top_k) | ||
| cached_plan = self._retrieval_plan_cache.get(plan_cache_key) | ||
| if cached_plan and time.monotonic() - cached_plan[0] < self._retrieval_plan_ttl_seconds: | ||
| ai_response = cached_plan[1] | ||
| logger.info("LLM retrieval plan cache hit (tool_calls=%d)", len(ai_response.tool_calls or [])) | ||
| else: | ||
| ai_response: AIMessage = await self.model_with_tools.ainvoke(messages) | ||
| self._retrieval_plan_cache[plan_cache_key] = (time.monotonic(), ai_response) | ||
|
Comment on lines
139
to
+184
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| logger.info("LLM response received (tool_calls=%d)", len(ai_response.tool_calls or [])) | ||
|
|
||
| # ── Step 2: Execute tool calls ──────────────────────────────── | ||
| sources: List[SourceRecord] = [] | ||
|
|
@@ -487,13 +499,18 @@ async def _search_snippet( | |
| # ------------------------------------------------------------------ | ||
|
|
||
| def _fetch_profile_catalog(self, user_id: str): | ||
| """Fetch all profile entries for a user. | ||
| """Fetch all profile entries for a user, caching the catalog briefly. | ||
|
|
||
| Returns: | ||
| (catalog, raw_results) | ||
| catalog — list of {topic, sub_topic} for the prompt | ||
| raw_results — the full SearchResult list, cached for _search_profile | ||
| """ | ||
| now = time.monotonic() | ||
| cached = self._profile_catalog_cache.get(user_id) | ||
| if cached and now - cached[0] < self._profile_catalog_ttl_seconds: | ||
| return cached[1], cached[2] | ||
|
|
||
| try: | ||
| results = self.vector_store.search_by_metadata( | ||
| filters={"user_id": user_id, "domain": "profile"}, | ||
|
|
@@ -524,6 +541,7 @@ def _fetch_profile_catalog(self, user_id: str): | |
| "sub_topic": "", | ||
| }) | ||
|
|
||
| self._profile_catalog_cache[user_id] = (now, catalog, results) | ||
| return catalog, results | ||
|
|
||
| def _format_catalog(self, catalog: List[Dict[str, str]]) -> str: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_record_search_latencycallslen(ordered)repeatedly inside the nestedpctclosure. Capturing it once asnavoids redundant calls and makes the guard condition symmetric.