diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index eb15b48ed..3eccb570a 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -1,4 +1,5 @@ import copy +import re import traceback from concurrent.futures import as_completed @@ -32,6 +33,7 @@ logger = get_logger(__name__) +KEYWORD_EXTRACT_TOP_K = 12 COT_DICT = { "fine": {"en": COT_PROMPT, "zh": COT_PROMPT_ZH}, "fast": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH}, @@ -278,7 +280,7 @@ def _parse_task( # retrieve related nodes by embedding related_nodes = [ - self.graph_store.get_node(n["id"]) + self.graph_store.get_node(n["id"], user_name=user_name) for n in self.graph_store.search_by_embedding( query_embedding, top_k=top_k, @@ -505,6 +507,39 @@ def _retrieve_from_working_memory( search_filter=search_filter, ) + @staticmethod + def _require_keyword_user_name(user_name: str | None) -> str: + normalized_user_name = user_name.strip() if isinstance(user_name, str) else "" + if not normalized_user_name: + raise ValueError( + "[PATH-KEYWORD] user_name is required for PolarDB fulltext keyword search" + ) + return normalized_user_name + + def _extract_weighted_keyword_terms(self, query: str) -> list[str]: + if detect_lang(query) == "zh": + import jieba.analyse + + weighted_terms = jieba.analyse.extract_tags(query, topK=KEYWORD_EXTRACT_TOP_K) + else: + weighted_terms = [] + if self.tokenizer: + weighted_terms = self.tokenizer.tokenize_mixed(query) + else: + weighted_terms = re.findall(r"\b[a-zA-Z0-9]+\b", query.lower()) + + query_words: list[str] = [] + seen_words: set[str] = set() + for term in weighted_terms: + normalized_term = str(term).strip() + if not normalized_term or normalized_term in seen_words: + continue + seen_words.add(normalized_term) + query_words.append(normalized_term) + if len(query_words) >= KEYWORD_EXTRACT_TOP_K: + break + return query_words + @timed def _retrieve_from_keyword( self, @@ -524,22 +559,21 @@ def _retrieve_from_keyword( return [] if not query_embedding: return [] + user_name = self._require_keyword_user_name(user_name) - query_words: list[str] = [] - if self.tokenizer: - query_words = self.tokenizer.tokenize_mixed(query) - else: - query_words = query.strip().split() - # Use unique tokens; avoid passing the raw query into `to_tsquery(...)` because it may contain - # spaces/operators that cause tsquery parsing errors. - query_words = list(dict.fromkeys(query_words)) - if len(query_words) > 64: - query_words = query_words[:64] + query_words = self._extract_weighted_keyword_terms(query) if not query_words: return [] + # Quote weighted terms before `to_tsquery(...)` to avoid parsing operators from user input. tsquery_terms = ["'" + w.replace("'", "''") + "'" for w in query_words if w and w.strip()] if not tsquery_terms: return [] + logger.info( + "[PATH-KEYWORD] weighted query_words=%s top_k=%s user_name=%s", + query_words, + top_k, + user_name, + ) scopes = [memory_type] if memory_type != "All" else ["LongTermMemory", "UserMemory"] @@ -548,7 +582,7 @@ def _retrieve_from_keyword( try: hits = self.graph_store.search_by_fulltext( query_words=tsquery_terms, - top_k=top_k * 2, + top_k=top_k, status="activated", scope=scope, search_filter=None, diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 800343732..f84fc60e1 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -429,7 +429,9 @@ def extract_edge_info(edges_info: list[dict], neighbor_relativity: float): for edge in edges_info: chunk_target_id = edge.get("to") edge_type = edge.get("type") - item_neighbor = self.searcher.graph_store.get_node(chunk_target_id) + item_neighbor = self.searcher.graph_store.get_node( + chunk_target_id, user_name=user_name + ) if item_neighbor: item_neighbor_mem = TextualMemoryItem(**item_neighbor) item_neighbor_mem.metadata.relativity = neighbor_relativity