Commit f1daca9

feat:[DocAgent] Dynamic RAG
1 parent d646be4 commit f1daca9

File tree

2 files changed: +271 -33 lines


autogen/agents/experimental/document_agent/document_agent.py

Lines changed: 3 additions & 0 deletions
@@ -84,6 +84,7 @@ def __init__(
         parsed_docs_path: str | Path | None = None,
         collection_name: str | None = None,
         query_engine: RAGQueryEngine | None = None,
+        rag_config: dict[str, dict[str, Any]] | None = None,  # NEW: {"vector": {}, "graph": {...}}
     ):
         """Initialize the DocAgent.

@@ -94,6 +95,7 @@ def __init__(
             parsed_docs_path: The path where parsed documents will be stored.
             collection_name: The unique name for the data store collection.
             query_engine: The query engine to use for querying documents.
+            rag_config: Configuration for RAG engines {"vector": {}, "graph": {...}}.
         """
         name = name or "DocAgent"
         llm_config = llm_config or LLMConfig.get_current_llm_config()
@@ -120,6 +122,7 @@ def __init__(
             query_engine=query_engine,
             parsed_docs_path=parsed_docs_path,
             collection_name=collection_name,
+            rag_config=rag_config,  # NEW
         )

         def update_ingested_documents() -> None:
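
For context, a minimal usage sketch of the new rag_config parameter (an illustrative assumption, not part of this commit; the collection name and Neo4j credentials below are placeholders):

from autogen.agents.experimental import DocAgent

# Hypothetical setup: enable both RAG engines. An empty dict for "vector"
# means VectorChromaQueryEngine defaults; the "graph" dict is forwarded to
# the Neo4j engine factory added in task_manager.py below.
agent = DocAgent(
    collection_name="my_docs",  # placeholder collection name
    rag_config={
        "vector": {},
        "graph": {
            "host": "bolt://localhost",
            "port": 7687,
            "username": "neo4j",
            "password": "neo4j",  # placeholder; use real credentials
        },
    },
)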

autogen/agents/experimental/document_agent/task_manager.py

Lines changed: 268 additions & 33 deletions
@@ -4,11 +4,18 @@

 import asyncio
 import logging
+import tempfile
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 from typing import Any

+import fitz  # PyMuPDF  # pyright: ignore[reportMissingImports]
+import requests
+import urllib3
+
 from autogen import ConversableAgent
+from autogen.agentchat.contrib.capabilities.text_compressors import LLMLingua
+from autogen.agentchat.contrib.capabilities.transforms import TextMessageCompressor
 from autogen.agentchat.contrib.rag.query_engine import RAGQueryEngine
 from autogen.agentchat.group.context_variables import ContextVariables
 from autogen.agentchat.group.reply_result import ReplyResult
@@ -43,6 +50,33 @@
 """


+def extract_text_from_pdf(doc_path: str) -> list[dict[str, str]]:
+    """Extract compressed text from a PDF file."""
+    if isinstance(doc_path, str) and urllib3.util.url.parse_url(doc_path).scheme:
+        # Download the PDF
+        response = requests.get(doc_path)
+        response.raise_for_status()  # Ensure the download was successful
+
+        text = ""
+        # Save the PDF to a temporary file
+        with tempfile.TemporaryDirectory() as temp_dir:
+            with open(Path(temp_dir) / "temp.pdf", "wb") as f:
+                f.write(response.content)
+
+            # Open the PDF
+            with fitz.open(Path(temp_dir) / "temp.pdf") as doc:
+                # Read and extract text from each page
+                for page in doc:
+                    text += page.get_text()
+        llm_lingua = LLMLingua()
+        text_compressor = TextMessageCompressor(text_compressor=llm_lingua)
+        compressed_text = text_compressor.apply_transform([{"content": text}])
+
+        return compressed_text
+    else:
+        raise ValueError("doc_path must be a URL string")
+
+
 @export_module("autogen.agents.experimental")
 class TaskManagerAgent(ConversableAgent):
     """TaskManagerAgent with integrated tools for document ingestion and query processing."""
@@ -57,6 +91,7 @@ def __init__(
         return_agent_error: str = "SummaryAgent",
         collection_name: str | None = None,
         max_workers: int | None = None,
+        rag_config: dict[str, dict[str, Any]] | None = None,
     ):
         """Initialize the TaskManagerAgent.

@@ -69,55 +104,222 @@ def __init__(
             return_agent_error: The agent to return on error
             collection_name: The collection name for the RAG query engine
             max_workers: Maximum number of threads for concurrent processing (None for default)
+            rag_config: Configuration for RAG engines {"vector": {}, "graph": {...}}
         """
-        self.query_engine = query_engine if query_engine else VectorChromaQueryEngine(collection_name=collection_name)
+        self.rag_config = rag_config or {"vector": {}}  # Default to vector only
         self.parsed_docs_path = Path(parsed_docs_path) if parsed_docs_path else Path("./parsed_docs")
         self.executor = ThreadPoolExecutor(max_workers=max_workers)

+        # Initialize RAG engines
+        self.rag_engines = self._create_rag_engines(collection_name)
+
+        # Keep backward compatibility
+        self.query_engine = query_engine if query_engine else self.rag_engines.get("vector")
+
+        def _aggregate_rag_results(self: "TaskManagerAgent", query: str, results: dict[str, Any]) -> str:
+            """Aggregate results from multiple RAG engines."""
+            if not results:
+                return f"Query: {query}\nAnswer: No results found from any RAG engine."
+
+            # Simple aggregation
+            answer_parts = [f"Query: {query}"]
+
+            for engine_name, result in results.items():
+                answer_parts.append(f"\n{engine_name.upper()} Results:")
+                answer_parts.append(f"Answer: {result.get('answer', 'No answer available')}")
+
+                # Add citations if available
+                if "citations" in result and result["citations"]:
+                    answer_parts.append("Citations:")
+                    for i, citation in enumerate(result["citations"], 1):
+                        answer_parts.append(f"  [{i}] {citation.get('file_path', 'Unknown')}")
+
+            return "\n".join(answer_parts)
+
         def _process_single_document(self: "TaskManagerAgent", input_file_path: str) -> tuple[str, bool, str]:
             """Process a single document. Returns (path, success, error_msg)."""
+
+            def compress_and_save_text(text: str, input_path: str) -> str:
+                """Compress text and save as markdown file."""
+                llm_lingua = LLMLingua()
+                text_compressor = TextMessageCompressor(text_compressor=llm_lingua)
+                compressed_text = text_compressor.apply_transform([{"content": text}])
+
+                # Create a markdown file with the extracted text
+                output_file = self.parsed_docs_path / f"{Path(input_path).stem}.md"
+                self.parsed_docs_path.mkdir(parents=True, exist_ok=True)
+
+                with open(output_file, "w", encoding="utf-8") as f:
+                    f.write(compressed_text[0]["content"])
+
+                return str(output_file)
+
+            def ingest_to_engines(self: "TaskManagerAgent", output_file: str, input_path: str) -> None:
+                """Ingest document to configured RAG engines."""
+                from autogen.agentchat.contrib.graph_rag.document import Document, DocumentType
+
+                # Determine document type
+                doc_type = DocumentType.TEXT
+                if input_path.lower().endswith(".pdf"):
+                    doc_type = DocumentType.PDF
+                elif input_path.lower().endswith((".html", ".htm")):
+                    doc_type = DocumentType.HTML
+                elif input_path.lower().endswith(".json"):
+                    doc_type = DocumentType.JSON
+
+                # Create Document object for graph engines
+                graph_doc = Document(doctype=doc_type, path_or_url=output_file, data=None)
+
+                # Ingest to configured engines only
+                for rag_type in self.rag_config.keys():
+                    engine = self.rag_engines.get(rag_type)
+                    if engine is None:
+                        continue
+
+                    try:
+                        if rag_type == "vector":
+                            engine.add_docs(new_doc_paths_or_urls=[output_file])
+                        elif rag_type == "graph":
+                            # For graph engines, we need to initialize if not done already
+                            if not hasattr(engine, "_initialized"):
+                                engine.init_db([graph_doc])
+                                engine._initialized = True
+                            else:
+                                # Add new records to existing graph
+                                if hasattr(engine, "add_records"):
+                                    engine.add_records([graph_doc])
+                    except Exception as e:
+                        logger.warning(f"Failed to ingest to {rag_type} engine: {e}")
+
             try:
-                output_files = docling_parse_docs(
-                    input_file_path=input_file_path,
-                    output_dir_path=self.parsed_docs_path,
-                    output_formats=["markdown"],
-                )
+                # Check if the document is a PDF
+                is_pdf = False
+                if isinstance(input_file_path, str) and (
+                    input_file_path.lower().endswith(".pdf")
+                    or (urllib3.util.url.parse_url(input_file_path).scheme and input_file_path.lower().endswith(".pdf"))
+                ):
+                    # Check for PDF extension or URL ending with .pdf
+                    is_pdf = True
+
+                if is_pdf:
+                    # Handle PDF with PyMuPDF
+                    logger.info("PDF found, using PyMuPDF")
+                    if urllib3.util.url.parse_url(input_file_path).scheme:
+                        # Download the PDF
+                        response = requests.get(input_file_path)
+                        response.raise_for_status()
+
+                        text = ""
+                        # Save the PDF to a temporary file and extract text
+                        with tempfile.TemporaryDirectory() as temp_dir:
+                            temp_pdf_path = Path(temp_dir) / "temp.pdf"
+                            with open(temp_pdf_path, "wb") as f:
+                                f.write(response.content)
+
+                            # Open the PDF and extract text
+                            with fitz.open(temp_pdf_path) as doc:
+                                for page in doc:
+                                    text += page.get_text()
+
+                        # Compress and save
+                        output_file = compress_and_save_text(text, input_file_path)
+
+                        # Ingest to all active engines
+                        ingest_to_engines(self, output_file, input_file_path)
+
+                        return (input_file_path, True, "")
+                    else:
+                        # Local PDF file
+                        text = ""
+                        with fitz.open(input_file_path) as doc:
+                            for page in doc:
+                                text += page.get_text()
+
+                        # Compress and save
+                        output_file = compress_and_save_text(text, input_file_path)
+
+                        # Ingest to all active engines
+                        ingest_to_engines(self, output_file, input_file_path)

-                # Limit to one output markdown file for now.
-                if output_files:
-                    output_file = output_files[0]
-                    if output_file.suffix == ".md":
-                        self.query_engine.add_docs(new_doc_paths_or_urls=[output_file])
                         return (input_file_path, True, "")
+                else:
+                    # Handle non-PDF documents with docling
+                    output_files = docling_parse_docs(
+                        input_file_path=input_file_path,
+                        output_dir_path=self.parsed_docs_path,
+                        output_formats=["markdown"],
+                    )
+
+                    # Limit to one output markdown file for now.
+                    if output_files:
+                        parsed_output_file: Path = output_files[0]
+                        if parsed_output_file.suffix == ".md":
+                            # Ingest to all active engines
+                            ingest_to_engines(self, str(parsed_output_file), input_file_path)
+                            return (input_file_path, True, "")
+
+                    return (input_file_path, False, "No valid markdown output generated")

-                return (input_file_path, False, "No valid markdown output generated")
             except Exception as doc_error:
                 return (input_file_path, False, str(doc_error))

         def _execute_single_query(self: "TaskManagerAgent", query_text: str) -> tuple[str, str]:
-            """Execute a single query. Returns (query, result)."""
+            """Execute a single query across configured RAG engines. Returns (query, result)."""
             try:
-                # Check for citations support
-                if (
-                    hasattr(self.query_engine, "enable_query_citations")
-                    and getattr(self.query_engine, "enable_query_citations", False)
-                    and hasattr(self.query_engine, "query_with_citations")
-                    and callable(getattr(self.query_engine, "query_with_citations", None))
-                ):
-                    answer_with_citations = getattr(self.query_engine, "query_with_citations")(query_text)
-                    answer = answer_with_citations.answer
-                    txt_citations = [
-                        {
-                            "text_chunk": source.node.get_text(),
-                            "file_path": source.metadata.get("file_path", "Unknown"),
-                        }
-                        for source in answer_with_citations.citations
-                    ]
-                    logger.info(f"Citations: {txt_citations}")
-                else:
-                    answer = self.query_engine.query(query_text) if self.query_engine else "Query engine not available"
+                results = {}
+
+                # Only query engines that are configured in rag_config
+                for rag_type in self.rag_config.keys():
+                    engine = self.rag_engines.get(rag_type)
+                    if engine is None:
+                        continue

-                return (query_text, f"Query: {query_text}\nAnswer: {answer}")
+                    try:
+                        if rag_type == "vector":
+                            # Handle vector queries
+                            if (
+                                hasattr(engine, "enable_query_citations")
+                                and getattr(engine, "enable_query_citations", False)
+                                and hasattr(engine, "query_with_citations")
+                                and callable(getattr(engine, "query_with_citations", None))
+                            ):
+                                answer_with_citations = getattr(engine, "query_with_citations")(query_text)
+                                answer = answer_with_citations.answer
+                                txt_citations = [
+                                    {
+                                        "text_chunk": source.node.get_text(),
+                                        "file_path": source.metadata.get("file_path", "Unknown"),
+                                    }
+                                    for source in answer_with_citations.citations
+                                ]
+                                results[rag_type] = {"answer": answer, "citations": txt_citations}
+                                logger.info(f"Vector Citations: {txt_citations}")
+                            else:
+                                answer = engine.query(query_text) if engine else "Vector engine not available"
+                                results[rag_type] = {"answer": answer}
+
+                        elif rag_type == "graph":
+                            # Handle graph queries
+                            # Try to connect to existing graph if not already connected
+                            if not hasattr(engine, "index"):
+                                try:
+                                    engine.connect_db()
+                                    logger.info("Connected to existing Neo4j graph for querying")
+                                except Exception as connect_error:
+                                    logger.warning(f"Failed to connect to Neo4j graph: {connect_error}")
+                                    results[rag_type] = {"answer": f"Error connecting to graph: {connect_error}"}
+                                    continue
+
+                            graph_result = engine.query(query_text)
+                            results[rag_type] = {"answer": graph_result.answer, "results": graph_result.results}
+
+                    except Exception as engine_error:
+                        logger.warning(f"Failed to query {rag_type} engine: {engine_error}")
+                        results[rag_type] = {"answer": f"Error querying {rag_type}: {engine_error}"}
+
+                # Aggregate results
+                aggregated_answer = _aggregate_rag_results(self, query_text, results)
+                return (query_text, aggregated_answer)

             except Exception as query_error:
                 logger.warning(f"Failed to execute query '{query_text}': {query_error}")
@@ -306,3 +508,36 @@ def __del__(self) -> None:
         """Clean up the ThreadPoolExecutor when the agent is destroyed."""
         if hasattr(self, "executor"):
             self.executor.shutdown(wait=True)
+
+    def _create_rag_engines(self, collection_name: str | None = None) -> dict[str, Any]:
+        """Create RAG engines based on rag_config."""
+        engines = {}
+
+        for rag_type, config in self.rag_config.items():
+            if rag_type == "vector":
+                engines["vector"] = VectorChromaQueryEngine(
+                    collection_name=config.get("collection_name", collection_name),
+                    **{k: v for k, v in config.items() if k != "collection_name"},
+                )
+            elif rag_type == "graph":
+                engines["graph"] = self._create_neo4j_engine(config)
+
+        return engines
+
+    def _create_neo4j_engine(self, config: dict[str, Any]) -> Any:
+        """Create Neo4j graph query engine."""
+        try:
+            from autogen.agentchat.contrib.graph_rag.neo4j_graph_query_engine import Neo4jGraphQueryEngine
+
+            return Neo4jGraphQueryEngine(
+                host=config.get("host", "bolt://localhost"),
+                port=config.get("port", 7687),
+                database=config.get("database", "neo4j"),
+                username=config.get("username", "neo4j"),
+                password=config.get("password", "neo4j"),
+                llm=config.get("llm"),
+                embedding=config.get("embedding"),
+            )
+        except ImportError as e:
+            logger.warning(f"Neo4j dependencies not available: {e}. Skipping graph engine.")
+            return None
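
The keys of the "graph" entry in rag_config map one-to-one onto these constructor arguments; a hedged sketch of a full configuration (all values are placeholders, and llm/embedding would be model objects if used):

rag_config = {
    "vector": {"collection_name": "docs"},  # overrides the agent-level collection_name
    "graph": {
        "host": "bolt://localhost",  # defaults mirror _create_neo4j_engine
        "port": 7687,
        "database": "neo4j",
        "username": "neo4j",
        "password": "neo4j",  # placeholder credentials
        # "llm": ..., "embedding": ...  # optional model objects
    },
}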
