diff --git a/backend/agents.py b/backend/agents.py
index 80743e4..436d8d1 100644
--- a/backend/agents.py
+++ b/backend/agents.py
@@ -3,6 +3,7 @@
 import re
 import json
 import asyncio
+from datetime import datetime
 from enum import Enum
 from typing import Dict, List, Optional, TypedDict, Any
@@ -11,7 +12,7 @@
 from ks_search_tool import general_search, general_search_async, global_fuzzy_keyword_search
 from retrieval import Retriever

-# LLM (Gemini) client setup 
+# LLM (Gemini) client setup
 try:
     from google import genai
     from google.genai import types as genai_types
@@ -84,7 +85,7 @@ def _get_genai_client():
 FLASH_LITE_MODEL = os.getenv("GEMINI_FLASH_LITE_MODEL", "gemini-2.5-flash-lite")

-# Query intent/types 
+# Query intent/types
 class QueryIntent(Enum):
     DATA_DISCOVERY = "data_discovery"
     ACCESS_DOWNLOAD = "access_download"
@@ -113,7 +114,7 @@ def _is_more_query(text: str) -> Optional[int]:
     return int(m.group(1)) if m else None

-# LLM calls using google.genai 
+# LLM calls using google.genai
 async def call_gemini_for_keywords(query: str) -> List[str]:
     """
     Extract raw keywords/phrases from the user's text using the LLM only.
@@ -312,7 +313,7 @@ def _cfg():
     return text

-# Agent state and search/fuse/response pipeline 
+# Agent state and search/fuse/response pipeline
 class AgentState(TypedDict):
     session_id: str
     query: str
@@ -360,21 +361,35 @@ async def run(self, query: str, want: int, context: Optional[Dict] = None) -> List[dict]:
             return []
         try:
             # Run the synchronous search in a thread to make it async
+            # Use enhanced search with hybrid and re-ranking
             results = await asyncio.to_thread(
-                self.retriever.search,
-                query=query,
-                top_k=min(want, 50),
-                context={"raw": True}
+                self.retriever.search,
+                query=query,
+                top_k=min(want, 50),
+                context={"raw": True},
+                use_hybrid=True,
+                use_rerank=True
             )
-            return [item.__dict__ if hasattr(item, "__dict__") else item for item in results]
+            # Convert to dict format and include enhanced scores
+            result_dicts = []
+            for item in results:
+                item_dict = item.__dict__.copy()
+                # Use the best available score for ranking (rerank > hybrid > similarity)
+                best_score = item.rerank_score or item.hybrid_score or item.similarity
+                item_dict['final_score'] = best_score
+                item_dict['similarity'] = item.similarity  # Keep original similarity
+                result_dicts.append(item_dict)
+            return result_dicts
         except Exception as e:
             print(f"Vector search error: {e}")
+            import traceback
+            traceback.print_exc()
             return []

 async def extract_keywords_and_rewrite(state: AgentState) -> AgentState:
     print("--- Node: Keywords, Rewrite, Intents ---")
-    # Detect intents on the raw input first 
+    # Detect intents on the raw input first
     intents0 = await call_gemini_detect_intents(state["query"], state.get("history", []))
     if intents0 == [QueryIntent.GREETING.value]:
         print("Pure greeting detected; skipping search.")
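One caveat on the `best_score` fallback chain added above: `a or b or c` treats any falsy value as missing, so a legitimate `rerank_score` of exactly 0.0 would be skipped in favor of the hybrid score. A minimal None-safe sketch (the helper name is illustrative, not part of the patch):

```python
from typing import Optional

def best_score(rerank: Optional[float], hybrid: Optional[float], similarity: float) -> float:
    """Return the highest-priority score that is actually present.

    Unlike `rerank or hybrid or similarity`, a valid score of exactly 0.0
    is not skipped here.
    """
    for score in (rerank, hybrid):
        if score is not None:
            return score
    return similarity

assert best_score(0.0, 0.9, 0.5) == 0.0   # the `or` chain would have returned 0.9
assert best_score(None, None, 0.5) == 0.5
```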
@@ -405,42 +420,134 @@ async def execute_search(state: AgentState) -> Dict[str, Any]:
         print("Pure greeting; skipping search.")
         return {"ks_results": [], "vector_results": []}
     want_pool = 60  # collect enough for several pages (15 per page)
-
+
     # Run both searches simultaneously using shared vector agent
     ks_agent = KSSearchAgent()
     vec_agent = get_vector_agent()  # Reuse the same instance
-
+
     ks_task = asyncio.create_task(
         ks_agent.run(state["effective_query"], state.get("keywords", []), want=want_pool)
     )
     vec_task = asyncio.create_task(
         vec_agent.run(query=state["effective_query"], want=want_pool, context={"raw": True})
     )
-
+
     # Wait for both searches to complete
     ks_results_data, vec_results = await asyncio.gather(ks_task, vec_task)
     all_ks_results = ks_results_data.get("combined_results", [])
-
+
     print(f"Search completed: KS results={len(all_ks_results)}, Vector results={len(vec_results)}")
     return {"ks_results": all_ks_results, "vector_results": vec_results}

+def calculate_advanced_score(result: dict, query: str, intents: List[str]) -> float:
+    """
+    Calculate an advanced score based on multiple factors including:
+    - Original similarity/relevance score
+    - Metadata relevance to query
+    - Content relevance to query
+    - Source reliability (if available)
+    - Temporal factors (if available)
+    """
+    # Base score from original system
+    base_score = result.get("similarity", 0) if "similarity" in result else result.get("_score", 0)
+
+    # Normalize base score to 0-1 range if needed
+    if base_score > 1:
+        base_score = min(base_score / 10.0, 1.0)  # Assuming scores are typically 0-10 scale
+
+    # Extract text content for relevance checking
+    title = result.get("title_guess", "") or result.get("title", "") or ""
+    content = result.get("content", "") or ""
+    metadata_text = ""
+
+    # Extract metadata for relevance
+    metadata = result.get("metadata", {}) or {}
+    if isinstance(metadata, dict):
+        for key, value in metadata.items():
+            if isinstance(value, (str, list)):
+                if isinstance(value, list):
+                    value = " ".join(str(v) for v in value if isinstance(v, str))
+                metadata_text += f" {value}"
+
+    # Combine all text for relevance analysis
+    all_text = f"{title} {content} {metadata_text}".lower()
+    query_lower = query.lower()
+
+    # Calculate query term matching
+    query_terms = query_lower.split()
+    matched_terms = 0
+    total_terms = len(query_terms)
+
+    for term in query_terms:
+        if term and term in all_text:
+            matched_terms += 1
+
+    # Term matching score (0-1)
+    term_match_score = matched_terms / total_terms if total_terms > 0 else 0.0
+
+    # Exact phrase matching bonus
+    phrase_bonus = 0.0
+    if query_lower in all_text:
+        phrase_bonus = 0.2  # Bonus for exact phrase match
+
+    # Intent-specific scoring
+    intent_bonus = 0.0
+    if intents:
+        # Check if result contains intent-related keywords
+        if QueryIntent.ACCESS_DOWNLOAD.value in intents:
+            access_keywords = ["download", "access", "license", "api", "format"]
+            if any(keyword in all_text for keyword in access_keywords):
+                intent_bonus += 0.1
+
+        if QueryIntent.METADATA_QUERY.value in intents:
+            metadata_keywords = ["metadata", "format", "spec", "parameter", "method"]
+            if any(keyword in all_text for keyword in metadata_keywords):
+                intent_bonus += 0.1
+
+    # Combine scores with weights
+    # Base score: 50%, Term matching: 30%, Phrase bonus: 15%, Intent bonus: 5%
+    final_score = (base_score * 0.5) + (term_match_score * 0.3) + (phrase_bonus * 0.15) + (intent_bonus * 0.05)
+
+    return min(final_score, 1.0)  # Cap at 1.0
+
+
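To make the 50/30/15/5 weighting concrete, here is a hand-computed example against `calculate_advanced_score` as defined above; the result dict is invented for illustration:

```python
# Query: "human hippocampus EEG"; similarity comes from the vector store (already 0-1).
result = {
    "title_guess": "Human hippocampus EEG recordings",
    "content": "Intracranial EEG from human hippocampus during memory tasks.",
    "metadata": {"species": ["human"], "modality": "EEG"},
    "similarity": 0.8,
}
# base_score = 0.8; all 3 query terms match -> term_match_score = 1.0;
# the title contains the exact phrase -> phrase_bonus = 0.2; no intents -> 0.
# final = 0.8*0.5 + 1.0*0.3 + 0.2*0.15 + 0.0*0.05 = 0.73
score = calculate_advanced_score(result, "human hippocampus EEG", [])
assert abs(score - 0.73) < 1e-9
```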
combined[doc_id] = {**res, "final_score": advanced_score, "source_type": "vector"} + + # Process KS results with advanced scoring for res in ks_results: if isinstance(res, dict): doc_id = res.get("_id") or res.get("id") or f"ks_{len(combined)}" if doc_id in combined: - combined[doc_id]["final_score"] += res.get("_score", 0) * 0.4 + # Combine scores if result exists from both sources + existing_score = combined[doc_id]["final_score"] + new_score = calculate_advanced_score(res, query, intents) + # Average the scores when result appears in both sources + combined[doc_id].update({**res, "final_score": (existing_score + new_score) / 2.0, "source_type": "combined"}) else: - combined[doc_id] = {**res, "final_score": res.get("_score", 0) * 0.4} + # Calculate advanced score for KS result + advanced_score = calculate_advanced_score(res, query, intents) + combined[doc_id] = {**res, "final_score": advanced_score, "source_type": "knowledge_space"} + + # Sort by final score in descending order all_sorted = sorted(combined.values(), key=lambda x: x.get("final_score", 0), reverse=True) print(f"Results summary: KS={len(ks_results)}, Vector={len(vector_results)}, Combined={len(all_sorted)}") page_size = 15 @@ -474,6 +581,10 @@ class NeuroscienceAssistant: def __init__(self): self.chat_history: Dict[str, List[str]] = {} self.session_memory: Dict[str, Dict[str, Any]] = {} + self.user_feedback: Dict[str, List[Dict]] = {} # Store user feedback for learning + self.entity_memory: Dict[str, Dict[str, Any]] = {} # Track entities mentioned in conversation + self.research_context: Dict[str, Dict[str, Any]] = {} # Track research threads and objectives + self.citation_tracker: Dict[str, Dict[str, List]] = {} # Track citations and references self.graph = self._build_graph() def _build_graph(self): @@ -492,7 +603,249 @@ def _build_graph(self): def reset_session(self, session_id: str): self.chat_history.pop(session_id, None) self.session_memory.pop(session_id, None) + self.entity_memory.pop(session_id, None) + self.research_context.pop(session_id, None) + self.citation_tracker.pop(session_id, None) + + def _extract_entities_from_query(self, query: str) -> List[str]: + """Extract key entities from user query using simple pattern matching""" + import re + + # Common neuroscience entities and patterns + patterns = [ + r'\b(human|mouse|rabbit|rat|monkey|drosophila|c\.elegans)\b', # species + r'\b(hippocampus|cortex|amygdala|thalamus|striatum|cerebellum|brainstem)\b', # brain regions + r'\b(EEG|fMRI|MRI|PET|MEG|DTI|NIRS)\b', # imaging techniques + r'\b(electrophysiology|patch clamp|calcium imaging|optogenetics)\b', # techniques + r'\b(BIDS|NWB|NIfTI|DICOM|HDF5)\b', # formats + r'\b(CC0|PDDL|license|open access)\b', # licenses + r'\b(gene|protein|receptor|channel)\b', # molecular + ] + + entities = [] + query_lower = query.lower() + for pattern in patterns: + matches = re.findall(pattern, query_lower) + entities.extend(matches) + + return list(set(entities)) # Remove duplicates + + def _update_entity_memory(self, session_id: str, query: str, results: List[dict]): + """Update entity memory based on query and results""" + if session_id not in self.entity_memory: + self.entity_memory[session_id] = {"entities": [], "preferences": {}, "context": {}} + + # Extract entities from query + query_entities = self._extract_entities_from_query(query) + self.entity_memory[session_id]["entities"].extend(query_entities) + + # Track user preferences based on results they engage with + for result in results[:3]: # Focus on top results + source = 
result.get("datasource_name", result.get("source", "unknown")) + if source not in self.entity_memory[session_id]["preferences"]: + self.entity_memory[session_id]["preferences"][source] = 0 + self.entity_memory[session_id]["preferences"][source] += 1 + + def _update_research_context(self, session_id: str, query: str, results: List[dict]): + """Update research context with objectives and thread information""" + if session_id not in self.research_context: + self.research_context[session_id] = { + "research_objectives": [], + "research_threads": [], + "current_focus": "", + "research_summary": "", + "next_steps": [] + } + + # Update research objectives based on query + if query.strip(): + objectives = self.research_context[session_id]["research_objectives"] + if query not in objectives: + objectives.append(query) + # Keep only the most recent 5 objectives + if len(objectives) > 5: + objectives.pop(0) + + # Track research threads + threads = self.research_context[session_id]["research_threads"] + if results: + # Add top result as part of current research thread + top_result = results[0] if results else {} + if top_result: + thread_entry = { + "query": query, + "result_id": top_result.get("id", ""), + "title": top_result.get("title_guess", ""), + "timestamp": datetime.now().isoformat() + } + threads.append(thread_entry) + # Keep only the most recent 10 thread entries + if len(threads) > 10: + threads.pop(0) + + # Update current focus + self.research_context[session_id]["current_focus"] = query + + def _update_citation_tracker(self, session_id: str, results: List[dict]): + """Track citations and references mentioned in conversation""" + if session_id not in self.citation_tracker: + self.citation_tracker[session_id] = { + "citations": [], + "referenced_datasets": [], + "accessed_datasets": [] + } + tracker = self.citation_tracker[session_id] + + for result in results: + dataset_id = result.get("id", "") + title = result.get("title_guess", "") + source = result.get("datasource_name", "") + + # Add to referenced datasets if not already there + ref_key = f"{dataset_id}:{title}" + if ref_key not in tracker["referenced_datasets"]: + tracker["referenced_datasets"].append(ref_key) + + # Add to accessed datasets (with timestamp) + access_entry = { + "dataset_id": dataset_id, + "title": title, + "source": source, + "timestamp": datetime.now().isoformat() + } + tracker["accessed_datasets"].append(access_entry) + + # Keep only the most recent 20 accessed datasets + if len(tracker["accessed_datasets"]) > 20: + tracker["accessed_datasets"] = tracker["accessed_datasets"][-20:] + + def _generate_research_summary(self, session_id: str) -> str: + """Generate a research summary for the current session""" + if session_id not in self.research_context: + return "" + + context = self.research_context[session_id] + summary_parts = [] + + # Add research objectives + if context["research_objectives"]: + summary_parts.append("Research Objectives:") + for i, obj in enumerate(context["research_objectives"][-3:], 1): # Last 3 objectives + summary_parts.append(f" {i}. {obj}") + + # Add current focus + if context["current_focus"]: + summary_parts.append(f"\nCurrent Focus: {context['current_focus']}") + + # Add accessed datasets + if session_id in self.citation_tracker: + accessed = self.citation_tracker[session_id]["accessed_datasets"] + if accessed: + summary_parts.append(f"\nAccessed Datasets ({len(accessed)}):") + for i, dataset in enumerate(accessed[-5:], 1): # Last 5 accessed + summary_parts.append(f" {i}. 
+    def _generate_research_summary(self, session_id: str) -> str:
+        """Generate a research summary for the current session"""
+        if session_id not in self.research_context:
+            return ""
+
+        context = self.research_context[session_id]
+        summary_parts = []
+
+        # Add research objectives
+        if context["research_objectives"]:
+            summary_parts.append("Research Objectives:")
+            for i, obj in enumerate(context["research_objectives"][-3:], 1):  # Last 3 objectives
+                summary_parts.append(f"  {i}. {obj}")
+
+        # Add current focus
+        if context["current_focus"]:
+            summary_parts.append(f"\nCurrent Focus: {context['current_focus']}")
+
+        # Add accessed datasets
+        if session_id in self.citation_tracker:
+            accessed = self.citation_tracker[session_id]["accessed_datasets"]
+            if accessed:
+                summary_parts.append(f"\nAccessed Datasets ({len(accessed)}):")
+                for i, dataset in enumerate(accessed[-5:], 1):  # Last 5 accessed
+                    summary_parts.append(f"  {i}. {dataset['title']} [{dataset['source']}]")
+
+        return "\n".join(summary_parts)
+
+    def _generate_next_steps_suggestions(self, session_id: str, query: str) -> List[str]:
+        """Generate next-step suggestions based on research context"""
+        suggestions = []
+
+        # Get research context
+        context = self.research_context.get(session_id, {})
+        tracker = self.citation_tracker.get(session_id, {})
+
+        # If there are accessed datasets, suggest related exploration
+        if tracker.get("accessed_datasets"):
+            last_dataset = tracker["accessed_datasets"][-1]
+            suggestions.append(f"Explore related datasets to '{last_dataset['title']}'")
+            suggestions.append(f"Look for datasets with similar methodology to '{last_dataset['title']}'")
+
+        # If there are research objectives, suggest continuation
+        if context.get("research_objectives"):
+            last_objective = context["research_objectives"][-1]
+            suggestions.append(f"Continue research on '{last_objective}'")
+            suggestions.append(f"Find datasets that complement '{last_objective}'")
+
+        # General suggestions based on query
+        if "human" in query.lower():
+            suggestions.append("Explore human-specific datasets")
+        elif "mouse" in query.lower():
+            suggestions.append("Look for mouse model datasets")
+
+        if "MRI" in query.upper():  # covers both MRI and fMRI
+            suggestions.append("Find related neuroimaging datasets")
+
+        # Limit to 3 suggestions
+        return suggestions[:3]
+
+    def _get_conversation_context(self, session_id: str) -> Dict[str, Any]:
+        """Get comprehensive conversation context for current session"""
+        context = {
+            "entities": self.entity_memory.get(session_id, {}).get("entities", []),
+            "preferences": self.entity_memory.get(session_id, {}).get("preferences", {}),
+            "recent_queries": [],
+            "research_context": self.research_context.get(session_id, {}),
+            "citations": self.citation_tracker.get(session_id, {}).get("referenced_datasets", []),
+            "research_summary": self._generate_research_summary(session_id)
+        }
+
+        # Get recent queries from chat history
+        if session_id in self.chat_history:
+            recent_chats = self.chat_history[session_id][-6:]  # Last 3 exchanges
+            for chat in recent_chats:
+                if chat.startswith("User: "):
+                    context["recent_queries"].append(chat[len("User: "):])  # Drop the "User: " prefix
+
+        return context
+
+    def _resolve_pronouns(self, query: str, session_id: str) -> str:
+        """Simple pronoun resolution to improve query understanding"""
+        context = self._get_conversation_context(session_id)
+
+        # Replace pronouns with relevant entities if available
+        query_lower = query.lower()
+
+        # If user says "these" or "those", try to map to recent entities
+        if "these" in query_lower or "those" in query_lower:
+            if context["entities"]:
+                # Replace with most recent entity
+                most_recent_entity = context["entities"][-1]
+                query = re.sub(r'\b(these|those)\b', most_recent_entity, query, flags=re.IGNORECASE)
+
+        # If user says "it" or "this", try to map to recent entities
+        if "it" in query_lower or "this" in query_lower:
+            if context["entities"]:
+                most_recent_entity = context["entities"][-1]
+                query = re.sub(r'\b(it|this)\b', most_recent_entity, query, flags=re.IGNORECASE)
+
+        return query
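The resolution above is deliberately naive: it always substitutes the most recently seen entity, so it can misfire when several entities are in play. A small illustration, assuming `assistant` is an initialized `NeuroscienceAssistant` (the session state is seeded by hand):

```python
assistant.entity_memory["s1"] = {
    "entities": ["hippocampus"], "preferences": {}, "context": {}
}

resolved = assistant._resolve_pronouns("Can I download these in BIDS format?", "s1")
# "these" is replaced by the most recently tracked entity.
assert resolved == "Can I download hippocampus in BIDS format?"
```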
future results""" + feedback_entry = { + "session_id": session_id, + "query": query, + "response": response, + "result_ids": result_ids, + "rating": rating, # 1-5 scale + "feedback_text": feedback_text, + "timestamp": datetime.now().isoformat() + } + + if session_id not in self.user_feedback: + self.user_feedback[session_id] = [] + + self.user_feedback[session_id].append(feedback_entry) + + # Keep only recent feedback (last 100 entries per session) + if len(self.user_feedback[session_id]) > 100: + self.user_feedback[session_id] = self.user_feedback[session_id][-100:] async def handle_chat(self, session_id: str, query: str, reset: bool = False) -> str: try: @@ -501,6 +854,9 @@ async def handle_chat(self, session_id: str, query: str, reset: bool = False) -> if session_id not in self.chat_history: self.chat_history[session_id] = [] + # Resolve pronouns in query for better understanding + resolved_query = self._resolve_pronouns(query, session_id) + more_count = _is_more_query(query) mem = self.session_memory.get(session_id, {}) if more_count is not None or (query.strip().lower() in {"more", "next", "continue", "more please", "show more", "keep going"}): @@ -516,7 +872,7 @@ async def handle_chat(self, session_id: str, query: str, reset: bool = False) -> intents = mem.get("intents", [QueryIntent.DATA_DISCOVERY.value]) effective_query = mem.get("effective_query", "") prev_text = mem.get("last_text", "") - + text = await call_gemini_for_final_synthesis( effective_query, batch, intents, start_number=start + 1, previous_text=prev_text ) @@ -533,7 +889,7 @@ async def handle_chat(self, session_id: str, query: str, reset: bool = False) -> initial_state: AgentState = { "session_id": session_id, - "query": query, + "query": resolved_query, # Use resolved query "history": self.chat_history[session_id][-10:], "keywords": [], "effective_query": "", @@ -549,8 +905,18 @@ async def handle_chat(self, session_id: str, query: str, reset: bool = False) -> final_state = await self.graph.ainvoke(initial_state) response_text = final_state.get("final_response", "I encountered an unexpected empty response.") + # Update entity memory with query and results + all_results = final_state.get("all_results", []) + self._update_entity_memory(session_id, query, all_results) + + # Update research context + self._update_research_context(session_id, query, all_results) + + # Update citation tracker + self._update_citation_tracker(session_id, all_results) + self.session_memory[session_id] = { - "all_results": final_state.get("all_results", []), + "all_results": all_results, "page": 1, "page_size": 15, "effective_query": final_state.get("effective_query", initial_state["query"]), diff --git a/backend/retrieval.py b/backend/retrieval.py index 866d6c9..f70fd16 100644 --- a/backend/retrieval.py +++ b/backend/retrieval.py @@ -4,10 +4,13 @@ import logging from dataclasses import dataclass from typing import Any, Dict, List, Optional +import re +import math import torch from google.cloud import aiplatform, bigquery from transformers import AutoModel, AutoTokenizer +from sentence_transformers import CrossEncoder logger = logging.getLogger("retrieval") logger.setLevel(logging.INFO) @@ -26,12 +29,14 @@ class RetrievedItem: metadata: Dict[str, Any] primary_link: Optional[str] other_links: List[str] - similarity: float + similarity: float + rerank_score: Optional[float] = None # New field for re-ranking score + hybrid_score: Optional[float] = None # New field for hybrid search score -class Retriever: +class AdvancedRetriever: """ - Vertex AI 
diff --git a/backend/retrieval.py b/backend/retrieval.py
index 866d6c9..f70fd16 100644
--- a/backend/retrieval.py
+++ b/backend/retrieval.py
@@ -4,10 +4,13 @@
 import logging
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional
+import re
+import math

 import torch
 from google.cloud import aiplatform, bigquery
 from transformers import AutoModel, AutoTokenizer
+from sentence_transformers import CrossEncoder

 logger = logging.getLogger("retrieval")
 logger.setLevel(logging.INFO)
@@ -26,12 +29,14 @@ class RetrievedItem:
     metadata: Dict[str, Any]
     primary_link: Optional[str]
     other_links: List[str]
-    similarity: float 
+    similarity: float
+    rerank_score: Optional[float] = None  # New field for re-ranking score
+    hybrid_score: Optional[float] = None  # New field for hybrid search score

-class Retriever:
+class AdvancedRetriever:
     """
-    Vertex AI Matching Engine retriever.
+    Advanced RAG system with hybrid search optimization and re-ranking.

     Environment variables required to enable vector search:
       - GCP_PROJECT_ID
@@ -41,25 +46,28 @@ class Retriever:

     Optional:
       - EMBED_MODEL_NAME default: nomic-ai/nomic-embed-text-v1.5
+      - RERANK_MODEL_NAME default: cross-encoder/ms-marco-MiniLM-L-6-v2
       - BQ_DATASET_ID default: ks_metadata
       - BQ_TABLE_ID default: docstore
       - BQ_LOCATION default: EU
       - EMBED_MAX_TOKENS default: 1024
       - QUERY_CHAR_LIMIT default: 8000
+      - HYBRID_ALPHA default: 0.7 (weight for vector vs keyword search)
+      - RERANK_TOP_K default: 50 (number of items to re-rank)
     """

     def __init__(self):
         self.project_id = os.getenv("GCP_PROJECT_ID", "")
         self.region = os.getenv("GCP_REGION", "")
         self.index_endpoint_full = os.getenv("INDEX_ENDPOINT_ID_FULL", "")
         self.deployed_id = os.getenv("DEPLOYED_INDEX_ID", "")
         self.embed_model_name = os.getenv("EMBED_MODEL_NAME", "nomic-ai/nomic-embed-text-v1.5")
+        self.rerank_model_name = os.getenv("RERANK_MODEL_NAME", "cross-encoder/ms-marco-MiniLM-L-6-v2")
         self.bq_dataset = os.getenv("BQ_DATASET_ID", "ks_metadata")
         self.bq_table = os.getenv("BQ_TABLE_ID", "docstore")
-        self.bq_location = os.getenv("BQ_LOCATION","EU")
+        self.bq_location = os.getenv("BQ_LOCATION", "EU")
+
         try:
             self.embed_max_tokens = int(os.getenv("EMBED_MAX_TOKENS", "1024"))
         except Exception:
@@ -68,6 +76,14 @@ def __init__(self):
             self.query_char_limit = int(os.getenv("QUERY_CHAR_LIMIT", "8000"))
         except Exception:
             self.query_char_limit = 8000
+        try:
+            self.hybrid_alpha = float(os.getenv("HYBRID_ALPHA", "0.7"))  # Weight for vector vs keyword
+        except Exception:
+            self.hybrid_alpha = 0.7
+        try:
+            self.rerank_top_k = int(os.getenv("RERANK_TOP_K", "50"))
+        except Exception:
+            self.rerank_top_k = 50

         # Enable only if everything is present
         self.is_enabled = all(
@@ -95,18 +111,24 @@ def __init__(self):
         try:
             self.device = "cuda" if torch.cuda.is_available() else "cpu"
+            # Initialize embedding model
             self.tokenizer = AutoTokenizer.from_pretrained(
                 self.embed_model_name, trust_remote_code=True
             )
             self.model = AutoModel.from_pretrained(
                 self.embed_model_name, trust_remote_code=True
             ).eval().to(self.device)
-            logger.info(f"Vector search initialized on device={self.device} using {self.embed_model_name}")
+
+            # Initialize re-ranking model
+            self.rerank_model = CrossEncoder(self.rerank_model_name, device=self.device)
+
+            logger.info(f"Advanced RAG system initialized on device={self.device}")
+            logger.info(f"Embedding model: {self.embed_model_name}")
+            logger.info(f"Re-ranking model: {self.rerank_model_name}")
         except Exception as e:
-            logger.error(f"Embedding model initialization failed: {e}")
+            logger.error(f"Model initialization failed: {e}")
             self.is_enabled = False

-    # Embedding
     def _embed(self, text: str) -> List[float]:
         """
         Returns a normalized embedding vector for the given text.
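For context on the new dependency: a cross-encoder scores each (query, passage) pair jointly instead of comparing precomputed embeddings, which is slower but usually more accurate, and `ms-marco-MiniLM-L-6-v2` returns unbounded relevance logits (higher is more relevant), not values in [0, 1]. A standalone sketch of the API as `_rerank_results` uses it:

```python
from sentence_transformers import CrossEncoder

model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
pairs = [
    ["human hippocampus EEG", "Intracranial EEG recordings from the human hippocampus"],
    ["human hippocampus EEG", "Gene expression atlas of the mouse cortex"],
]
scores = model.predict(pairs)  # one relevance logit per pair
assert scores[0] > scores[1]   # the on-topic passage should outscore the off-topic one
```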
+ """ + if not query or not content: + return 0.0 + + query_lower = query.lower() + content_lower = content.lower() + + # Split into terms + query_terms = set(re.findall(r'\b\w+\b', query_lower)) + content_terms = re.findall(r'\b\w+\b', content_lower) + + # Calculate term frequency and matching + matches = 0 + total_terms = len(content_terms) + + for term in query_terms: + term_count = content_lower.count(term) + matches += term_count + + # Calculate a normalized score + if total_terms == 0: + return 0.0 + + # Use a logarithmic scale to prevent very long documents from dominating + keyword_score = matches / (math.log(1 + total_terms) + 1) + + # Boost score if exact phrase is found + if query_lower in content_lower: + keyword_score *= 1.5 + + return min(keyword_score, 1.0) # Cap at 1.0 + + def _hybrid_score(self, vector_score: float, keyword_score: float, alpha: float = 0.7) -> float: + """ + Combine vector similarity and keyword scores using weighted average. + """ + return alpha * vector_score + (1 - alpha) * keyword_score + + def _rerank_results(self, query: str, items: List[RetrievedItem]) -> List[RetrievedItem]: + """ + Re-rank results using a cross-encoder model for better relevance. + """ + if not items or len(items) == 0: + return items + + # Prepare pairs for cross-encoder (query, content) + pairs = [] + valid_items = [] + + for item in items: + content = f"{item.title_guess} {item.content}".strip() + if content: # Only process items with content + pairs.append([query, content]) + valid_items.append(item) + + if not pairs: + return items + + # Get re-ranking scores + try: + rerank_scores = self.rerank_model.predict(pairs) + + # Convert to list if needed + if hasattr(rerank_scores, 'tolist'): + rerank_scores = rerank_scores.tolist() + + # Update items with re-ranking scores + for i, score in enumerate(rerank_scores): + valid_items[i].rerank_score = float(score) + + except Exception as e: + logger.error(f"Re-ranking failed: {e}") + # If re-ranking fails, use original similarity scores + for item in valid_items: + item.rerank_score = item.similarity + + # Sort by re-ranking score in descending order + valid_items.sort(key=lambda x: x.rerank_score or x.similarity, reverse=True) + + return valid_items + def _bq_fetch(self, ids: List[str]) -> Dict[str, Dict[str, Any]]: if not ids: return {} @@ -162,11 +267,11 @@ def _bq_fetch(self, ids: List[str]) -> Dict[str, Dict[str, Any]]: return out def search( - self, query: str, top_k: int = 20, context: Optional[Dict[str, Any]] = None + self, query: str, top_k: int = 20, context: Optional[Dict[str, Any]] = None, + use_hybrid: bool = True, use_rerank: bool = True ) -> List[RetrievedItem]: """ - Executes a similarity search in Matching Engine. - + Executes an advanced search with hybrid optimization and re-ranking. 
""" if not self.is_enabled or not query: return [] @@ -180,7 +285,8 @@ def search( return [] try: - n = max(1, min((top_k or 20) * 2, 100)) + # Get more results than needed for re-ranking + n = max(1, min((top_k or 20) * 3, 100)) results = self.index_ep.find_neighbors( deployed_index_id=self.deployed_id, queries=[vec], num_neighbors=n ) @@ -213,6 +319,9 @@ def search( ) try: similarity = -float(dist) if dist is not None else 0.0 + # Normalize similarity to 0-1 range if needed + if similarity < 0: + similarity = 1.0 / (1.0 + math.exp(-similarity)) # Sigmoid normalization except Exception: similarity = 0.0 @@ -228,8 +337,59 @@ def search( ) ) - items.sort(key=lambda x: x.similarity, reverse=True) + # Apply hybrid search if enabled + if use_hybrid: + for item in items: + keyword_score = self._keyword_score(qtext, f"{item.title_guess} {item.content}") + item.hybrid_score = self._hybrid_score(item.similarity, keyword_score, self.hybrid_alpha) + + # Apply re-ranking if enabled + if use_rerank and len(items) > 0: + # Only re-rank top candidates to save computation + rerank_count = min(len(items), self.rerank_top_k) + top_items = items[:rerank_count] + remaining_items = items[rerank_count:] + + # Re-rank the top items + reranked_items = self._rerank_results(qtext, top_items) + + # Combine with remaining items (which keep their original order/scores) + items = reranked_items + remaining_items + + # Sort by the most appropriate score (rerank > hybrid > similarity) + items.sort(key=lambda x: ( + x.rerank_score or + x.hybrid_score or + x.similarity + ), reverse=True) + return items[: (top_k or 20)] except Exception as e: - logger.error(f"Matching Engine search failed: {e}") + logger.error(f"Advanced search failed: {e}") + import traceback + traceback.print_exc() + return [] + + +class Retriever: + """ + Legacy Retriever class for backward compatibility. + """ + def __init__(self): + self.advanced_retriever = AdvancedRetriever() + # Copy attributes for backward compatibility + self.is_enabled = self.advanced_retriever.is_enabled if self.advanced_retriever else False + self.project_id = getattr(self.advanced_retriever, 'project_id', '') + self.region = getattr(self.advanced_retriever, 'region', '') + self.index_endpoint_full = getattr(self.advanced_retriever, 'index_endpoint_full', '') + self.deployed_id = getattr(self.advanced_retriever, 'deployed_id', '') + + def search( + self, query: str, top_k: int = 20, context: Optional[Dict[str, Any]] = None + ) -> List[RetrievedItem]: + """ + Executes a similarity search using the advanced retriever. + """ + if not hasattr(self, 'advanced_retriever') or not self.advanced_retriever: return [] + return self.advanced_retriever.search(query, top_k, context, use_hybrid=True, use_rerank=True)