58 changes: 46 additions & 12 deletions src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -1,4 +1,5 @@
import copy
import re
import traceback

from concurrent.futures import as_completed
@@ -32,6 +33,7 @@


logger = get_logger(__name__)
KEYWORD_EXTRACT_TOP_K = 12
COT_DICT = {
"fine": {"en": COT_PROMPT, "zh": COT_PROMPT_ZH},
"fast": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH},
@@ -278,7 +280,7 @@ def _parse_task(

# retrieve related nodes by embedding
related_nodes = [
self.graph_store.get_node(n["id"])
self.graph_store.get_node(n["id"], user_name=user_name)
for n in self.graph_store.search_by_embedding(
query_embedding,
top_k=top_k,
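For orientation, here is a minimal sketch of the user-scoped lookup this hunk introduces; the store interface below is assumed from the call sites visible in the diff, not taken from the project's actual API.

```python
# Hypothetical sketch: route each embedding hit through a user-scoped get_node.
# search_by_embedding/get_node signatures are assumed from this diff's call sites.
def fetch_related_nodes(graph_store, query_embedding, top_k, user_name):
    hits = graph_store.search_by_embedding(query_embedding, top_k=top_k)
    # user_name scopes the read to one tenant's nodes (assumed semantics of this PR).
    return [graph_store.get_node(h["id"], user_name=user_name) for h in hits]
```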
@@ -505,6 +507,39 @@ def _retrieve_from_working_memory(
search_filter=search_filter,
)

@staticmethod
def _require_keyword_user_name(user_name: str | None) -> str:
normalized_user_name = user_name.strip() if isinstance(user_name, str) else ""
if not normalized_user_name:
raise ValueError(
"[PATH-KEYWORD] user_name is required for PolarDB fulltext keyword search"
)
return normalized_user_name

def _extract_weighted_keyword_terms(self, query: str) -> list[str]:
if detect_lang(query) == "zh":
import jieba.analyse

weighted_terms = jieba.analyse.extract_tags(query, topK=KEYWORD_EXTRACT_TOP_K)
else:
weighted_terms = []
if self.tokenizer:
weighted_terms = self.tokenizer.tokenize_mixed(query)
else:
weighted_terms = re.findall(r"\b[a-zA-Z0-9]+\b", query.lower())

query_words: list[str] = []
seen_words: set[str] = set()
for term in weighted_terms:
normalized_term = str(term).strip()
if not normalized_term or normalized_term in seen_words:
continue
seen_words.add(normalized_term)
query_words.append(normalized_term)
if len(query_words) >= KEYWORD_EXTRACT_TOP_K:
break
return query_words
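A standalone sketch of the dedup-and-cap step implemented above, runnable without jieba or a tokenizer; it exercises only the English regex fallback branch.

```python
import re

KEYWORD_EXTRACT_TOP_K = 12  # mirrors the constant added at the top of searcher.py

def extract_terms(query: str) -> list[str]:
    # English fallback path: lowercase alphanumeric tokens.
    terms = re.findall(r"\b[a-zA-Z0-9]+\b", query.lower())
    out: list[str] = []
    seen: set[str] = set()
    for term in terms:
        term = term.strip()
        if not term or term in seen:
            continue  # skip empties and duplicates, preserving first-seen order
        seen.add(term)
        out.append(term)
        if len(out) >= KEYWORD_EXTRACT_TOP_K:
            break  # cap the number of tsquery terms
    return out

print(extract_terms("PolarDB fulltext search, fulltext keyword search"))
# -> ['polardb', 'fulltext', 'search', 'keyword']
```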

@timed
def _retrieve_from_keyword(
self,
@@ -524,22 +559,21 @@ def _retrieve_from_keyword(
return []
if not query_embedding:
return []
user_name = self._require_keyword_user_name(user_name)

query_words: list[str] = []
if self.tokenizer:
query_words = self.tokenizer.tokenize_mixed(query)
else:
query_words = query.strip().split()
# Use unique tokens; avoid passing the raw query into `to_tsquery(...)` because it may contain
# spaces/operators that cause tsquery parsing errors.
query_words = list(dict.fromkeys(query_words))
if len(query_words) > 64:
query_words = query_words[:64]
query_words = self._extract_weighted_keyword_terms(query)
if not query_words:
return []
# Quote weighted terms before `to_tsquery(...)` to avoid parsing operators from user input.
tsquery_terms = ["'" + w.replace("'", "''") + "'" for w in query_words if w and w.strip()]
if not tsquery_terms:
return []
logger.info(
"[PATH-KEYWORD] weighted query_words=%s top_k=%s user_name=%s",
query_words,
top_k,
user_name,
)
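The quoting above exists because PostgreSQL's `to_tsquery` treats characters like `&`, `|`, and `!` as operators. Below is a sketch of how the quoted terms might be combined into a safe tsquery string; the `|` join is an assumption here, since the diff leaves the actual composition to `search_by_fulltext`.

```python
def build_tsquery(terms: list[str]) -> str:
    # Single-quote each term and double embedded quotes, as in the diff above,
    # so raw user input cannot inject tsquery operators.
    quoted = ["'" + t.replace("'", "''") + "'" for t in terms if t.strip()]
    return " | ".join(quoted)  # OR-join is assumed; the store decides in practice

print(build_tsquery(["polar", "o'reilly", "full & text"]))
# -> 'polar' | 'o''reilly' | 'full & text'
```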

scopes = [memory_type] if memory_type != "All" else ["LongTermMemory", "UserMemory"]

@@ -548,7 +582,7 @@
try:
hits = self.graph_store.search_by_fulltext(
query_words=tsquery_terms,
top_k=top_k * 2,
top_k=top_k,
status="activated",
scope=scope,
search_filter=None,
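A hedged sketch of the per-scope loop this hunk tunes (the fetch width drops from `top_k * 2` to `top_k`): run the fulltext search once per memory scope and merge hits, deduplicating by node id. Error handling is simplified, and any arguments the diff truncates are omitted.

```python
def keyword_search_all_scopes(graph_store, tsquery_terms, top_k, scopes):
    merged: dict[str, dict] = {}
    for scope in scopes:
        try:
            hits = graph_store.search_by_fulltext(
                query_words=tsquery_terms,
                top_k=top_k,  # was top_k * 2 before this change
                status="activated",
                scope=scope,
                search_filter=None,
            )
        except Exception:
            continue  # the real except handling is truncated in this diff
        for hit in hits:
            merged.setdefault(hit["id"], hit)  # dedupe across scopes by node id
    return list(merged.values())
```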
4 changes: 3 additions & 1 deletion src/memos/multi_mem_cube/single_cube.py
@@ -429,7 +429,9 @@ def extract_edge_info(edges_info: list[dict], neighbor_relativity: float):
for edge in edges_info:
chunk_target_id = edge.get("to")
edge_type = edge.get("type")
item_neighbor = self.searcher.graph_store.get_node(chunk_target_id)
item_neighbor = self.searcher.graph_store.get_node(
chunk_target_id, user_name=user_name
)
if item_neighbor:
item_neighbor_mem = TextualMemoryItem(**item_neighbor)
item_neighbor_mem.metadata.relativity = neighbor_relativity
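A similar sketch for the neighbor-expansion change in single_cube.py: each edge target is now fetched with the same `user_name` scoping before being wrapped as a memory item. `TextualMemoryItem` is the project's own model, so it is passed in here as `item_cls` to keep the sketch self-contained.

```python
# Hypothetical sketch of extract_edge_info's neighbor loop after this change.
def expand_neighbors(graph_store, edges_info, neighbor_relativity, user_name, item_cls):
    neighbors = []
    for edge in edges_info:
        raw = graph_store.get_node(edge.get("to"), user_name=user_name)
        if not raw:
            continue  # target missing, or filtered out by the user_name scope
        item = item_cls(**raw)  # e.g. TextualMemoryItem in the real code
        item.metadata.relativity = neighbor_relativity
        neighbors.append(item)
    return neighbors
```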