
Commit b8354ac

fix: add type safety, validation, thread safety
1 parent c67851f commit b8354ac


3 files changed: +142 -93 lines changed


autogen/agents/experimental/document_agent/document_agent.py

Lines changed: 86 additions & 77 deletions
@@ -162,6 +162,7 @@ def create_summary_agent_prompt(agent: ConversableAgent, messages: list[dict[str
     """Create the summary agent prompt with context information."""
     update_ingested_documents()
 
+    # Safe type casting with defaults
     query_results = cast(list[dict[str, Any]], agent.context_variables.get("QueryResults", []))
     documents_ingested = cast(list[str], agent.context_variables.get("DocumentsIngested", []))
     documents_to_ingest = cast(list[Ingest], agent.context_variables.get("DocumentsToIngest", []))
@@ -209,86 +210,94 @@ def generate_inner_group_chat_reply(
         config: Any = None,
     ) -> tuple[bool, str | dict[str, Any] | None]:
         """Reply function that generates the inner group chat reply for the DocAgent."""
-        # Initialize or reuse context variables
-        if hasattr(self, "_group_chat_context_variables") and self._group_chat_context_variables is not None:
-            context_variables = self._group_chat_context_variables
-            # Reset pending tasks for new run
-            context_variables["DocumentsToIngest"] = []
-        else:
-            context_variables = ContextVariables(
-                data={
-                    "CompletedTaskCount": 0,
-                    "DocumentsToIngest": [],
-                    "DocumentsIngested": self.documents_ingested,
-                    "QueriesToRun": [],
-                    "QueryResults": [],
-                }
+        try:
+            # Initialize or reuse context variables
+            if hasattr(self, "_group_chat_context_variables") and self._group_chat_context_variables is not None:
+                context_variables = self._group_chat_context_variables
+                # Reset pending tasks for new run
+                context_variables["DocumentsToIngest"] = []
+            else:
+                context_variables = ContextVariables(
+                    data={
+                        "CompletedTaskCount": 0,
+                        "DocumentsToIngest": [],
+                        "DocumentsIngested": self.documents_ingested,
+                        "QueriesToRun": [],
+                        "QueryResults": [],
+                    }
+                )
+                self._group_chat_context_variables = context_variables
+
+            if messages and len(messages) > 0:
+                last_message = messages[-1]
+                if (
+                    isinstance(last_message, dict)
+                    and last_message.get("name") == "DocumentTriageAgent"
+                    and "content" in last_message
+                    and isinstance(last_message["content"], str)
+                ):
+                    try:
+                        import json
+
+                        document_task_data = json.loads(last_message["content"])
+
+                        # Extract ingestions and queries
+                        ingestions = [Ingest(**ing) for ing in document_task_data.get("ingestions", [])]
+                        queries = [Query(**q) for q in document_task_data.get("queries", [])]
+
+                        # Update context variables with new tasks
+                        existing_ingestions = context_variables.get("DocumentsToIngest", []) or []
+                        existing_queries = context_variables.get("QueriesToRun", []) or []
+                        documents_ingested = context_variables.get("DocumentsIngested", []) or []
+
+                        # Deduplicate and add new ingestions
+                        for ingestion in ingestions:
+                            if (
+                                ingestion.path_or_url not in [ing.path_or_url for ing in existing_ingestions]
+                                and ingestion.path_or_url not in documents_ingested
+                            ):
+                                existing_ingestions.append(ingestion)
+
+                        # Deduplicate and add new queries
+                        for query in queries:
+                            if query.query not in [q.query for q in existing_queries]:
+                                existing_queries.append(query)
+
+                        context_variables["DocumentsToIngest"] = existing_ingestions
+                        context_variables["QueriesToRun"] = existing_queries
+                        context_variables["TaskInitiated"] = True
+
+                        logger.info(f"Processed triage output: {len(ingestions)} ingestions, {len(queries)} queries")
+
+                    except json.JSONDecodeError as e:
+                        logger.warning(f"Failed to parse triage output JSON: {e}")
+                    except Exception as e:
+                        logger.warning(f"Failed to process triage output: {e}")
+
+            group_chat_agents = [
+                self._triage_agent,
+                self._task_manager_agent,
+                self._summary_agent,
+            ]
+
+            agent_pattern = DefaultPattern(
+                initial_agent=self._triage_agent,
+                agents=group_chat_agents,
+                context_variables=context_variables,
+                group_after_work=TerminateTarget(),
             )
-            self._group_chat_context_variables = context_variables
-
-        if messages and len(messages) > 0:
-            last_message = messages[-1]
-            if (
-                isinstance(last_message, dict)
-                and last_message.get("name") == "DocumentTriageAgent"
-                and "content" in last_message
-            ):
-                try:
-                    import json
-
-                    document_task_data = json.loads(last_message["content"])
-
-                    # Extract ingestions and queries
-                    ingestions = [Ingest(**ing) for ing in document_task_data.get("ingestions", [])]
-                    queries = [Query(**q) for q in document_task_data.get("queries", [])]
-
-                    # Update context variables with new tasks
-                    existing_ingestions = context_variables.get("DocumentsToIngest", []) or []
-                    existing_queries = context_variables.get("QueriesToRun", []) or []
-                    documents_ingested = context_variables.get("DocumentsIngested", []) or []
-
-                    # Deduplicate and add new ingestions
-                    for ingestion in ingestions:
-                        if (
-                            ingestion.path_or_url not in [ing.path_or_url for ing in existing_ingestions]
-                            and ingestion.path_or_url not in documents_ingested
-                        ):
-                            existing_ingestions.append(ingestion)
-
-                    # Deduplicate and add new queries
-                    for query in queries:
-                        if query.query not in [q.query for q in existing_queries]:
-                            existing_queries.append(query)
-
-                    context_variables["DocumentsToIngest"] = existing_ingestions
-                    context_variables["QueriesToRun"] = existing_queries
-                    context_variables["TaskInitiated"] = True
-
-                    logger.info(f"Processed triage output: {len(ingestions)} ingestions, {len(queries)} queries")
-
-                except Exception as e:
-                    logger.warning(f"Failed to process triage output: {e}")
-
-        group_chat_agents = [
-            self._triage_agent,
-            self._task_manager_agent,
-            self._summary_agent,
-        ]
-
-        agent_pattern = DefaultPattern(
-            initial_agent=self._triage_agent,
-            agents=group_chat_agents,
-            context_variables=context_variables,
-            group_after_work=TerminateTarget(),
-        )
 
-        chat_result, context_variables, last_speaker = initiate_group_chat(
-            pattern=agent_pattern,
-            messages=self._get_document_input_message(messages),
-        )
+            chat_result, context_variables, last_speaker = initiate_group_chat(
+                pattern=agent_pattern,
+                messages=self._get_document_input_message(messages),
+            )
+
+            # Always return the final result since we only have summary termination
+            return True, chat_result.summary
 
-        # Always return the final result since we only have summary termination
-        return True, chat_result.summary
+        except Exception as e:
+            logger.error(f"Critical error in DocAgent group chat: {e}")
+            return True, f"Error processing request: {str(e)}"
 
     def _get_document_input_message(self, messages: list[dict[str, Any]] | None) -> str:
         """Gets and validates the input message(s) for the document agent."""

autogen/agents/experimental/document_agent/task_manager.py

Lines changed: 49 additions & 13 deletions
@@ -4,6 +4,7 @@
 
 import asyncio
 import logging
+import threading
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 from typing import Any
@@ -81,6 +82,7 @@ def __init__(
         self.parsed_docs_path = Path(parsed_docs_path) if parsed_docs_path else Path("./parsed_docs")
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
         self._temp_citations_store: dict[str, list[dict[str, str]]] = {}
+        self._context_lock = threading.Lock()
 
         # Initialize RAG engines
         self.rag_engines = self._create_rag_engines(collection_name)
@@ -100,6 +102,27 @@ async def ingest_documents(
         Returns:
             str: Status message about the ingestion process
         """
+        # Add input validation
+        if not documents_to_ingest:
+            return ReplyResult(
+                message="No documents provided for ingestion",
+                context_variables=context_variables,
+            )
+
+        # Validate document paths/URLs
+        valid_documents = []
+        for doc_path in documents_to_ingest:
+            if isinstance(doc_path, str) and doc_path.strip():
+                valid_documents.append(doc_path.strip())
+            else:
+                logger.warning(f"Invalid document path: {doc_path}")
+
+        if not valid_documents:
+            return ReplyResult(
+                message="No valid documents found for ingestion",
+                context_variables=context_variables,
+            )
+
         # Safely handle context variable initialization
         if "DocumentsToIngest" not in context_variables:
             context_variables["DocumentsToIngest"] = []
@@ -111,7 +134,7 @@ async def ingest_documents(
             context_variables["QueriesToRun"] = []
 
         # Add current batch to pending ingestions
-        context_variables["DocumentsToIngest"].append(documents_to_ingest)
+        context_variables["DocumentsToIngest"].append(valid_documents)
 
         try:
             # Process documents concurrently using ThreadPoolExecutor
@@ -125,7 +148,7 @@
                     self.rag_config,
                     self.rag_engines,
                 )
-                for doc_path in documents_to_ingest
+                for doc_path in valid_documents
             ]
 
             # Wait for all documents to be processed
@@ -151,7 +174,7 @@
             logger.info("=" * 80)
             logger.info("TOOL: ingest_documents (CONCURRENT)")
             logger.info("AGENT: TaskManagerAgent")
-            logger.info(f"DOCUMENTS: {documents_to_ingest}")
+            logger.info(f"DOCUMENTS: {valid_documents}")
             logger.info(f"SUCCESSFULLY INGESTED: {successfully_ingested}")
             logger.info("=" * 80)
 
@@ -175,11 +198,11 @@
             logger.error("TOOL ERROR: ingest_documents (CONCURRENT)")
             logger.error("AGENT: TaskManagerAgent")
             logger.error(f"ERROR: {e}")
-            logger.error(f"DOCUMENTS: {documents_to_ingest}")
+            logger.error(f"DOCUMENTS: {valid_documents}")
             logger.error("=" * 80)
 
             # Preserve failed documents for retry
-            context_variables["DocumentsToIngest"] = [documents_to_ingest]
+            context_variables["DocumentsToIngest"] = [valid_documents]
             return ReplyResult(
                 message=f"Documents ingestion failed: {e}",
                 context_variables=context_variables,
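The validation added to `ingest_documents` (and mirrored for queries below) boils down to one filter: keep stripped, non-empty strings, warn about everything else. A minimal sketch of that contract; `validate_paths` is a hypothetical helper name, not something this commit introduces:

# Minimal sketch of the input-validation contract used above.
import logging

logger = logging.getLogger(__name__)


def validate_paths(documents_to_ingest: list) -> list[str]:
    """Keep stripped, non-empty string paths/URLs; warn on everything else."""
    valid: list[str] = []
    for doc_path in documents_to_ingest:
        if isinstance(doc_path, str) and doc_path.strip():
            valid.append(doc_path.strip())
        else:
            logger.warning(f"Invalid document path: {doc_path}")
    return valid


assert validate_paths(["  report.pdf ", "", None, 42]) == ["report.pdf"]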
@@ -198,6 +221,11 @@ async def execute_query(queries_to_run: list[str], context_variables: ContextVar
         if not queries_to_run:
             return "No queries to run"
 
+        # Validate queries
+        valid_queries = [q.strip() for q in queries_to_run if isinstance(q, str) and q.strip()]
+        if not valid_queries:
+            return "No valid queries provided"
+
         # Safely handle context variable initialization
         if "QueriesToRun" not in context_variables:
             context_variables["QueriesToRun"] = []
@@ -209,7 +237,7 @@ async def execute_query(queries_to_run: list[str], context_variables: ContextVar
             context_variables["Citations"] = []
 
         # Add current batch to pending queries
-        context_variables["QueriesToRun"].append(queries_to_run)
+        context_variables["QueriesToRun"].append(valid_queries)
 
         try:
             # Clear temporary citations store before processing
@@ -219,7 +247,7 @@ async def execute_query(queries_to_run: list[str], context_variables: ContextVar
             loop = asyncio.get_event_loop()
             futures = [
                 loop.run_in_executor(self.executor, execute_single_query, query, self.rag_config, self.rag_engines)
-                for query in queries_to_run
+                for query in valid_queries
             ]
 
             # Wait for all queries to be processed
@@ -252,15 +280,15 @@ async def execute_query(queries_to_run: list[str], context_variables: ContextVar
             logger.info("=" * 80)
             logger.info("TOOL: execute_query (CONCURRENT)")
             logger.info("AGENT: TaskManagerAgent")
-            logger.info(f"QUERIES: {queries_to_run}")
+            logger.info(f"QUERIES: {valid_queries}")
             logger.info("=" * 80)
 
             # Update context variables
             context_variables["QueriesToRun"].pop(0)  # Remove processed batch
             context_variables["CompletedTaskCount"] += 1
 
             # Store query results with citations
-            query_result = {"query": queries_to_run, "answer": answers, "citations": all_citations}
+            query_result = {"query": valid_queries, "answer": answers, "citations": all_citations}
             context_variables["QueryResults"].append(query_result)
             # Clear temporary citations store after processing
             self._temp_citations_store = {}
@@ -271,13 +299,13 @@ async def execute_query(queries_to_run: list[str], context_variables: ContextVar
             )
 
         except Exception as e:
-            error_msg = f"Query failed for queries '{queries_to_run}': {str(e)}"
+            error_msg = f"Query failed for queries '{valid_queries}': {str(e)}"
 
             # Enhanced error logging
             logger.error("=" * 80)
             logger.error("TOOL ERROR: execute_query (CONCURRENT)")
             logger.error("AGENT: TaskManagerAgent")
-            logger.error(f"QUERIES: {queries_to_run}")
+            logger.error(f"QUERIES: {valid_queries}")
             logger.error(f"ERROR: {e}")
             logger.error("=" * 80)
 
@@ -298,8 +326,11 @@ async def execute_query(queries_to_run: list[str], context_variables: ContextVar
 
     def __del__(self) -> None:
        """Clean up the ThreadPoolExecutor when the agent is destroyed."""
-        if hasattr(self, "executor"):
-            self.executor.shutdown(wait=True)
+        if hasattr(self, "executor") and self.executor is not None:
+            try:
+                self.executor.shutdown(wait=False)  # Don't block in destructor
+            except Exception as e:
+                logger.warning(f"Error shutting down executor: {e}")
 
     def _create_rag_engines(self, collection_name: str | None = None) -> dict[str, Any]:
         """Create RAG engines based on rag_config."""
@@ -335,3 +366,8 @@ def _create_neo4j_engine(self, config: dict[str, Any]) -> Any:
         except ImportError as e:
             logger.warning(f"Neo4j dependencies not available: {e}. Skipping graph engine.")
             return None
+
+    def _safe_context_update(self, context_variables: ContextVariables, key: str, value: Any) -> None:
+        """Thread-safe context variable update."""
+        with self._context_lock:
+            context_variables[key] = value
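`_safe_context_update` serializes writes through the new `threading.Lock`, but note it guards a single assignment; a read-modify-write sequence such as appending to `DocumentsToIngest` needs the whole sequence held under the lock to be race-free. A small sketch under that assumption, with a plain dict standing in for `ContextVariables`:

# Sketch of lock-guarded context updates; a dict stands in for ContextVariables.
import threading
from typing import Any


class ContextStore:
    def __init__(self) -> None:
        self._context_lock = threading.Lock()
        self._data: dict[str, Any] = {}

    def safe_update(self, key: str, value: Any) -> None:
        with self._context_lock:
            self._data[key] = value

    def safe_append(self, key: str, item: Any) -> None:
        # a read-modify-write must happen entirely under the lock
        with self._context_lock:
            self._data.setdefault(key, []).append(item)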

autogen/agents/experimental/document_agent/task_manager_utils.py

Lines changed: 7 additions & 3 deletions
@@ -28,11 +28,12 @@ def extract_text_from_pdf(doc_path: str) -> list[dict[str, str]]:
     text = ""
     # Save the PDF to a temporary file
     with tempfile.TemporaryDirectory() as temp_dir:
-        with open(temp_dir + "temp.pdf", "wb") as f:
+        temp_pdf_path = Path(temp_dir) / "temp.pdf"
+        with open(temp_pdf_path, "wb") as f:
             f.write(response.content)
 
         # Open the PDF
-        with fitz.open(temp_dir + "temp.pdf") as doc:
+        with fitz.open(str(temp_pdf_path)) as doc:
             # Read and extract text from each page
             for page in doc:
                 text += page.get_text()
@@ -72,6 +73,9 @@ def compress_and_save_text(text: str, input_path: str, parsed_docs_path: Path) -
     text_compressor = TextMessageCompressor(text_compressor=llm_lingua)
     compressed_text = text_compressor.apply_transform([{"content": text}])
 
+    if not compressed_text or not compressed_text[0].get("content"):
+        raise ValueError("Text compression failed or returned empty result")
+
     # Create a markdown file with the extracted text
     output_file = parsed_docs_path / f"{Path(input_path).stem}.md"
     parsed_docs_path.mkdir(parents=True, exist_ok=True)
@@ -138,7 +142,7 @@ def process_single_document(
 
     if is_pdf:
         # Handle PDF with PyMuPDF
-        print("PDF found using PyMuPDF")
+        logger.info("PDF found, using PyMuPDF for extraction")
        if urllib3.util.url.parse_url(input_file_path).scheme:
            # Download the PDF
            response = requests.get(input_file_path)
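The utils fix addresses a real bug: `temp_dir + "temp.pdf"` concatenates without a path separator, producing a sibling file like `/tmp/abc123temp.pdf` that escapes the temporary directory's cleanup, whereas `Path(temp_dir) / "temp.pdf"` joins correctly. A standalone sketch of the fixed flow, assuming `requests` and PyMuPDF are installed (the timeout and status check are extra hardening, not part of this commit):

# Standalone sketch of the fixed download-and-extract flow.
import tempfile
from pathlib import Path

import fitz  # PyMuPDF
import requests


def pdf_url_to_text(url: str) -> str:
    """Download a PDF into a temp dir and concatenate the text of every page."""
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_pdf_path = Path(temp_dir) / "temp.pdf"  # joins with a separator, unlike str +
        temp_pdf_path.write_bytes(response.content)
        with fitz.open(str(temp_pdf_path)) as doc:
            return "".join(page.get_text() for page in doc)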
