From 222df44d21db746bcf073a4891191851da3fd5ef Mon Sep 17 00:00:00 2001
From: "Charlie.Wei"
Date: Mon, 17 Feb 2025 14:09:57 +0800
Subject: [PATCH] Retrieval Service efficiency optimization (#13543)

---
 api/configs/middleware/__init__.py           |   6 +
 api/core/rag/datasource/retrieval_service.py | 288 +++++++++++--------
 2 files changed, 170 insertions(+), 124 deletions(-)

diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index f6a44eaa471e62..af1d5d7497b566 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -1,3 +1,4 @@
+import os
 from typing import Any, Literal, Optional
 from urllib.parse import quote_plus
 
@@ -166,6 +167,11 @@ def SQLALCHEMY_DATABASE_URI(self) -> str:
         default=False,
     )
 
+    RETRIEVAL_SERVICE_WORKER: NonNegativeInt = Field(
+        description="Number of worker threads for the retrieval service thread pool, defaults to the CPU count.",
+        default=os.cpu_count(),
+    )
+
     @computed_field
     def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
         return {
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 927df0efc42706..07a905594430f0 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -1,9 +1,11 @@
+import concurrent.futures
 import json
-import threading
 from typing import Optional
 
 from flask import Flask, current_app
+from sqlalchemy.orm import load_only
 
+from configs import dify_config
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
@@ -27,6 +29,7 @@
 
 
 class RetrievalService:
+    # Dispatches retrieval across keyword, semantic and full-text search
     @classmethod
     def retrieve(
         cls,
@@ -41,74 +44,62 @@
     ):
         if not query:
             return []
-        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
-        if not dataset:
-            return []
-
+        dataset = cls._get_dataset(dataset_id)
         if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
             return []
+
         all_documents: list[Document] = []
-        threads: list[threading.Thread] = []
        exceptions: list[str] = []
-        # retrieval_model source with keyword
-        if retrieval_method == "keyword_search":
-            keyword_thread = threading.Thread(
-                target=RetrievalService.keyword_search,
-                kwargs={
-                    "flask_app": current_app._get_current_object(),  # type: ignore
-                    "dataset_id": dataset_id,
-                    "query": query,
-                    "top_k": top_k,
-                    "all_documents": all_documents,
-                    "exceptions": exceptions,
-                },
-            )
-            threads.append(keyword_thread)
-            keyword_thread.start()
-        # retrieval_model source with semantic
-        if RetrievalMethod.is_support_semantic_search(retrieval_method):
-            embedding_thread = threading.Thread(
-                target=RetrievalService.embedding_search,
-                kwargs={
-                    "flask_app": current_app._get_current_object(),  # type: ignore
-                    "dataset_id": dataset_id,
-                    "query": query,
-                    "top_k": top_k,
-                    "score_threshold": score_threshold,
-                    "reranking_model": reranking_model,
-                    "all_documents": all_documents,
-                    "retrieval_method": retrieval_method,
-                    "exceptions": exceptions,
-                },
-            )
-            threads.append(embedding_thread)
-            embedding_thread.start()
-
-        # retrieval source with full text
-        if RetrievalMethod.is_support_fulltext_search(retrieval_method):
-            full_text_index_thread = threading.Thread(
-                target=RetrievalService.full_text_index_search,
-                kwargs={
-                    "flask_app": current_app._get_current_object(),  # type: ignore
-                    "dataset_id": dataset_id,
-                    "query": query,
- "retrieval_method": retrieval_method, - "score_threshold": score_threshold, - "top_k": top_k, - "reranking_model": reranking_model, - "all_documents": all_documents, - "exceptions": exceptions, - }, - ) - threads.append(full_text_index_thread) - full_text_index_thread.start() - for thread in threads: - thread.join() + # Optimize multithreading with thread pools + with concurrent.futures.ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_WORKER) as executor: # type: ignore + futures = [] + if retrieval_method == "keyword_search": + futures.append( + executor.submit( + cls.keyword_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset_id, + query=query, + top_k=top_k, + all_documents=all_documents, + exceptions=exceptions, + ) + ) + if RetrievalMethod.is_support_semantic_search(retrieval_method): + futures.append( + executor.submit( + cls.embedding_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset_id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + all_documents=all_documents, + retrieval_method=retrieval_method, + exceptions=exceptions, + ) + ) + if RetrievalMethod.is_support_fulltext_search(retrieval_method): + futures.append( + executor.submit( + cls.full_text_index_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset_id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + all_documents=all_documents, + retrieval_method=retrieval_method, + exceptions=exceptions, + ) + ) + concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED) if exceptions: - exception_message = ";\n".join(exceptions) - raise ValueError(exception_message) + raise ValueError(";\n".join(exceptions)) if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: data_post_processor = DataPostProcessor( @@ -133,18 +124,21 @@ def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model ) return all_documents + @classmethod + def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]: + return db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + @classmethod def keyword_search( cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list ): with flask_app.app_context(): try: - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = cls._get_dataset(dataset_id) if not dataset: raise ValueError("dataset not found") keyword = Keyword(dataset=dataset) - documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k) all_documents.extend(documents) except Exception as e: @@ -165,12 +159,11 @@ def embedding_search( ): with flask_app.app_context(): try: - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = cls._get_dataset(dataset_id) if not dataset: raise ValueError("dataset not found") vector = Vector(dataset=dataset) - documents = vector.search_by_vector( query, search_type="similarity_score_threshold", @@ -187,7 +180,7 @@ def embedding_search( and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value ): data_post_processor = DataPostProcessor( - str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False + str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False ) all_documents.extend( data_post_processor.invoke( @@ -217,13 +210,11 @@ def 
     ):
         with flask_app.app_context():
             try:
-                dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+                dataset = cls._get_dataset(dataset_id)
                 if not dataset:
                     raise ValueError("dataset not found")
 
-                vector_processor = Vector(
-                    dataset=dataset,
-                )
+                vector_processor = Vector(dataset=dataset)
 
                 documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)
                 if documents:
@@ -234,7 +225,7 @@ def full_text_index_search(
                     and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
                 ):
                     data_post_processor = DataPostProcessor(
-                        str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
+                        str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
                    )
                     all_documents.extend(
                         data_post_processor.invoke(
@@ -253,64 +244,105 @@ def full_text_index_search(
     def escape_query_for_search(query: str) -> str:
         return json.dumps(query).strip('"')
 
-    @staticmethod
-    def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]:
-        records = []
-        include_segment_ids = []
-        segment_child_map = {}
-        for document in documents:
-            document_id = document.metadata.get("document_id")
-            dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
-            if dataset_document:
+    @classmethod
+    def format_retrieval_documents(cls, documents: list[Document]) -> list[RetrievalSegments]:
+        """Format retrieval documents with optimized batch processing"""
+        if not documents:
+            return []
+
+        try:
+            # Collect document IDs
+            document_ids = {doc.metadata.get("document_id") for doc in documents if "document_id" in doc.metadata}
+            if not document_ids:
+                return []
+
+            # Batch query dataset documents
+            dataset_documents = {
+                doc.id: doc
+                for doc in db.session.query(DatasetDocument)
+                .filter(DatasetDocument.id.in_(document_ids))
+                .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
+                .all()
+            }
+
+            records = []
+            include_segment_ids = set()
+            segment_child_map = {}
+
+            # Process documents
+            for document in documents:
+                document_id = document.metadata.get("document_id")
+                if document_id not in dataset_documents:
+                    continue
+
+                dataset_document = dataset_documents[document_id]
+
                 if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                    # Handle parent-child documents
                     child_index_node_id = document.metadata.get("doc_id")
-                    result = (
-                        db.session.query(ChildChunk, DocumentSegment)
-                        .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
+
+                    child_chunk = (
+                        db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
+                    )
+
+                    if not child_chunk:
+                        continue
+
+                    segment = (
+                        db.session.query(DocumentSegment)
                         .filter(
-                            ChildChunk.index_node_id == child_index_node_id,
                             DocumentSegment.dataset_id == dataset_document.dataset_id,
                             DocumentSegment.enabled == True,
                             DocumentSegment.status == "completed",
+                            DocumentSegment.id == child_chunk.segment_id,
+                        )
+                        .options(
+                            load_only(
+                                DocumentSegment.id,
+                                DocumentSegment.content,
+                                DocumentSegment.answer,
+                                DocumentSegment.doc_metadata,
+                            )
                         )
                         .first()
                     )
-                    if result:
-                        child_chunk, segment = result
-                        if not segment:
-                            continue
-                        if segment.id not in include_segment_ids:
-                            include_segment_ids.append(segment.id)
-                            child_chunk_detail = {
-                                "id": child_chunk.id,
-                                "content": child_chunk.content,
-                                "position": child_chunk.position,
-                                "score": document.metadata.get("score", 0.0),
-                            }
-                            map_detail = {
-                                "max_score": document.metadata.get("score", 0.0),
-                                "child_chunks": [child_chunk_detail],
-                            }
-                            segment_child_map[segment.id] = map_detail
-                            record = {
-                                "segment": segment,
-                            }
-                            records.append(record)
-                        else:
-                            child_chunk_detail = {
-                                "id": child_chunk.id,
-                                "content": child_chunk.content,
-                                "position": child_chunk.position,
-                                "score": document.metadata.get("score", 0.0),
-                            }
-                            segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
-                            segment_child_map[segment.id]["max_score"] = max(
-                                segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
-                            )
-                    else:
+
+                    if not segment:
                         continue
+
+                    if segment.id not in include_segment_ids:
+                        include_segment_ids.add(segment.id)
+                        child_chunk_detail = {
+                            "id": child_chunk.id,
+                            "content": child_chunk.content,
+                            "position": child_chunk.position,
+                            "score": document.metadata.get("score", 0.0),
+                        }
+                        map_detail = {
+                            "max_score": document.metadata.get("score", 0.0),
+                            "child_chunks": [child_chunk_detail],
+                        }
+                        segment_child_map[segment.id] = map_detail
+                        record = {
+                            "segment": segment,
+                        }
+                        records.append(record)
+                    else:
+                        child_chunk_detail = {
+                            "id": child_chunk.id,
+                            "content": child_chunk.content,
+                            "position": child_chunk.position,
+                            "score": document.metadata.get("score", 0.0),
+                        }
+                        segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
+                        segment_child_map[segment.id]["max_score"] = max(
+                            segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                        )
                 else:
-                    index_node_id = document.metadata["doc_id"]
+                    # Handle normal documents
+                    index_node_id = document.metadata.get("doc_id")
+                    if not index_node_id:
+                        continue
 
                     segment = (
                         db.session.query(DocumentSegment)
@@ -325,16 +357,24 @@ def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegme
 
                     if not segment:
                         continue
-                    include_segment_ids.append(segment.id)
+
+                    include_segment_ids.add(segment.id)
                     record = {
                         "segment": segment,
-                        "score": document.metadata.get("score", None),
+                        "score": document.metadata.get("score"),  # type: ignore
+                        "segment_metadata": segment.doc_metadata,
                     }
-                    records.append(record)
+
+                    records.append(record)
+
+            # Add child chunks information to records
             for record in records:
                 if record["segment"].id in segment_child_map:
-                    record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
+                    record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks")  # type: ignore
                    record["score"] = segment_child_map[record["segment"].id]["max_score"]
 
-        return [RetrievalSegments(**record) for record in records]
+            return [RetrievalSegments(**record) for record in records]
+        except Exception as e:
+            db.session.rollback()
+            raise e
+        finally:
+            db.session.close()
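
The core of the patch is replacing per-request threading.Thread fan-out with a bounded concurrent.futures.ThreadPoolExecutor sized by the new RETRIEVAL_SERVICE_WORKER setting. Below is a minimal, standalone sketch of that submit-and-wait pattern only; the keyword_search/embedding_search stand-ins and the WORKERS constant are illustrative placeholders, not the dify API.

import concurrent.futures
import os

# Illustrative stand-in for dify_config.RETRIEVAL_SERVICE_WORKER (defaults to the CPU count).
WORKERS = os.cpu_count() or 4


def keyword_search(query: str, results: list, errors: list) -> None:
    # Placeholder for the real keyword-index lookup.
    try:
        results.append(f"keyword:{query}")
    except Exception as e:
        errors.append(str(e))


def embedding_search(query: str, results: list, errors: list) -> None:
    # Placeholder for the real vector-similarity search.
    try:
        results.append(f"embedding:{query}")
    except Exception as e:
        errors.append(str(e))


def retrieve(query: str) -> list:
    results: list = []
    errors: list = []
    # A bounded pool reuses a fixed set of worker threads instead of
    # spawning fresh threads for every retrieval request.
    with concurrent.futures.ThreadPoolExecutor(max_workers=WORKERS) as executor:
        futures = [
            executor.submit(keyword_search, query, results, errors),
            executor.submit(embedding_search, query, results, errors),
        ]
        concurrent.futures.wait(futures, timeout=30)
    if errors:
        raise ValueError(";\n".join(errors))
    return results


if __name__ == "__main__":
    print(retrieve("what is dify"))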