From dbb802113d56c12141afe52670f5492531778173 Mon Sep 17 00:00:00 2001
From: zjwu0522
Date: Wed, 20 Aug 2025 08:50:51 +0000
Subject: [PATCH] fix: return token usage and log output when the current run
 hits a streaming error

---
 src/agent.py             | 445 ++++++++++++++++++++++++++++++---------
 src/base/task_manager.py |  10 +-
 src/evaluator.py         |  37 ++--
 3 files changed, 364 insertions(+), 128 deletions(-)

diff --git a/src/agent.py b/src/agent.py
index 5c6e49f5..996f2928 100644
--- a/src/agent.py
+++ b/src/agent.py
@@ -14,6 +14,7 @@
 # Python stdlib
 import asyncio
+import json
 import os
 import time
 from typing import Any, Dict, Callable
@@ -43,9 +44,17 @@
 # Initialize logger
 logger = get_logger(__name__)

-import nest_asyncio
-nest_asyncio.apply()
+
+def _apply_nest_asyncio():
+    """Apply nest_asyncio to allow nested event loops."""
+    import nest_asyncio
+
+    nest_asyncio.apply()
+
+
+# Apply nested asyncio support
+_apply_nest_asyncio()


 class MCPAgent:
@@ -122,7 +131,11 @@ def _create_model_provider(self) -> ModelProvider:
         client = AsyncOpenAI(
             base_url=self.base_url,
             api_key=self.api_key,
-            default_headers={ "App-Code": "LobeHub", 'HTTP-Referer': 'https://lobehub.com', 'X-Title': 'LobeHub' }
+            default_headers={
+                "App-Code": "LobeHub",
+                "HTTP-Referer": "https://lobehub.com",
+                "X-Title": "LobeHub",
+            },
         )

         agent_model_name = self.model_name  # Capture the model name from the agent
@@ -152,11 +165,19 @@ async def _create_mcp_server(self) -> Any:
         cfg = self.service_config  # shorthand

         # Services that use npx or pipx and need startup delay
-        NPX_BASED_SERVICES = ["notion", "filesystem", "playwright", "playwright_webarena"]
+        NPX_BASED_SERVICES = [
+            "notion",
+            "filesystem",
+            "playwright",
+            "playwright_webarena",
+        ]
         PIPX_BASED_SERVICES = ["postgres"]
-
+
         # Add startup delay for npx-based and pipx-based services to ensure proper initialization
-        if self.mcp_service in NPX_BASED_SERVICES or self.mcp_service in PIPX_BASED_SERVICES:
+        if (
+            self.mcp_service in NPX_BASED_SERVICES
+            or self.mcp_service in PIPX_BASED_SERVICES
+        ):
             logger.debug(f"Adding startup delay for service: {self.mcp_service}")
             await asyncio.sleep(5)

@@ -317,19 +338,38 @@ async def _create_mcp_server(self) -> Any:
         else:
             raise ValueError(f"Unsupported MCP service: {self.mcp_service}")

-    async def _execute_with_streaming(self, instruction: str) -> Dict[str, Any]:
+    def _write_to_log_file(self, log_file_path: str, content: str):
+        """Write content to log file, creating directory if needed."""
+        if log_file_path:
+            try:
+                import os
+
+                os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
+                with open(log_file_path, "a", encoding="utf-8") as f:
+                    f.write(content)
+            except Exception as log_error:
+                logger.debug(f"Failed to write to log file: {log_error}")
+
+    async def _execute_with_streaming(
+        self, instruction: str, tool_call_log_file: str = None
+    ) -> Dict[str, Any]:
         """
         Execute instruction with agent using streaming response.

         Args:
             instruction: The instruction/prompt to execute
-            (Service configuration is taken from self.service_config)
+            tool_call_log_file: Optional path to log tool calls
+            (Service configuration is taken from self.service_config)

         Returns:
             Dictionary containing execution results
         """
         start_time = time.time()

+        # Initialize partial results to preserve even on failure
+        partial_output = []
+        partial_token_usage = {}
+        partial_turn_count = 0
+
         try:
             # Refresh service configuration before each execution
             self._refresh_service_config()
@@ -366,55 +406,208 @@ async def _execute_with_streaming(self, instruction: str) -> Dict[str, Any]:
             line_prefix = "| "
             at_line_start = True
             last_event_type = None  # Track the previous event type
-            async for event in result.stream_events():
-                event_count += 1
-                logger.debug(f"Event {event_count}: {event}")
-
-                if hasattr(event, "type"):
-                    logger.debug(f"Event type: {event.type}")
-
-                    if event.type == "raw_response_event":
-                        if hasattr(event, "data") and isinstance(
-                            event.data, ResponseTextDeltaEvent
-                        ):
-                            delta_text = event.data.delta or ""
-                            # Stream with line prefix, handling chunked newlines
-                            for chunk in delta_text.splitlines(True):  # keepends=True
-                                if at_line_start:
-                                    print(line_prefix, end="", flush=True)
-                                print(chunk, end="", flush=True)
-                                at_line_start = chunk.endswith("\n")
-
-                            last_event_type = "text_output"
-
-                    elif event.type == "run_item_stream_event":
-                        if (
-                            hasattr(event, "item")
-                            and getattr(event.item, "type", "") == "tool_call_item"
-                        ):
-                            if last_event_type == "text_output":
-                                # Add newline if text wasn't already on a new line
-                                if not at_line_start:
-                                    print("\n", end="", flush=True)
-                                at_line_start = True
-
-                            tool_name = getattr(
-                                getattr(event.item, "raw_item", None),
-                                "name",
-                                "Unknown",
-                            )
-
-                            arguments = getattr(getattr(event.item, "raw_item", None), 'arguments', None)
-
-                            if isinstance(arguments, str):
-                                display_arguments = arguments[:140] + "..." if len(arguments) > 140 else arguments
-                            else:
-                                display_arguments = arguments
-                            logger.info(
-                                f"| \033[1m{tool_name}\033[0m \033[2;37m{display_arguments}\033[0m"
-                            )
-
-                            last_event_type = "tool_call"
+
+            # Track if max_turns was exceeded
+            max_turns_exceeded = False
+
+            try:
+                async for event in result.stream_events():
+                    event_count += 1
+                    logger.debug(f"Event {event_count}: {event}")
+
+                    if hasattr(event, "type"):
+                        logger.debug(f"Event type: {event.type}")
+
+                        if event.type == "raw_response_event":
+                            if hasattr(event, "data") and isinstance(
+                                event.data, ResponseTextDeltaEvent
+                            ):
+                                delta_text = event.data.delta or ""
+                                # Stream with line prefix, handling chunked newlines
+                                for chunk in delta_text.splitlines(
+                                    True
+                                ):  # keepends=True
+                                    if at_line_start:
+                                        print(line_prefix, end="", flush=True)
+                                    print(chunk, end="", flush=True)
+                                    at_line_start = chunk.endswith("\n")
+
+                                # Also log text output to file (preserve original formatting)
+                                if delta_text.strip():  # Only log non-empty content
+                                    self._write_to_log_file(
+                                        tool_call_log_file, delta_text
+                                    )
+
+                                last_event_type = "text_output"
+
+                        elif event.type == "run_item_stream_event":
+                            if (
+                                hasattr(event, "item")
+                                and getattr(event.item, "type", "")
+                                == "tool_call_item"
+                            ):
+                                if last_event_type == "text_output":
+                                    # Add newline if text wasn't already on a new line
+                                    if not at_line_start:
+                                        print("\n", end="", flush=True)
+                                    at_line_start = True
+
+                                tool_name = getattr(
+                                    getattr(event.item, "raw_item", None),
+                                    "name",
+                                    "Unknown",
+                                )
+
+                                arguments = getattr(
+                                    getattr(event.item, "raw_item", None),
+                                    "arguments",
+                                    None,
+                                )
+
+                                if isinstance(arguments, str):
+                                    display_arguments = (
+                                        arguments[:140] + "..."
+                                        if len(arguments) > 140
+                                        else arguments
+                                    )
+                                else:
+                                    # Convert non-string arguments to single-line JSON
+                                    try:
+                                        args_str = json.dumps(
+                                            arguments, separators=(",", ": ")
+                                        )
+                                        display_arguments = (
+                                            args_str[:140] + "..."
+                                            if len(args_str) > 140
+                                            else args_str
+                                        )
+                                    except Exception:
+                                        display_arguments = str(arguments)[:140]
+                                logger.info(
+                                    f"| \033[1m{tool_name}\033[0m \033[2;37m{display_arguments}\033[0m"
+                                )
+
+                                # Also log tool call to log file (ensure proper line breaks)
+                                args_str = (
+                                    arguments
+                                    if isinstance(arguments, str)
+                                    else json.dumps(
+                                        arguments, separators=(",", ": ")
+                                    )
+                                )
+                                # Add newline before tool call if previous was text output
+                                prefix = (
+                                    "\n" if last_event_type == "text_output" else ""
+                                )
+                                self._write_to_log_file(
+                                    tool_call_log_file,
+                                    f"{prefix}| {tool_name} {args_str}\n",
+                                )
+
+                                last_event_type = "tool_call"
+
+            except Exception as stream_error:
+                error_msg = f"Error during streaming: {stream_error}"
+                logger.error(error_msg, exc_info=True)
+                # Also log error to file (ensure proper line break)
+                self._write_to_log_file(
+                    tool_call_log_file, f"\n| ERROR: {error_msg}\n"
+                )
+
+                # Try to extract whatever conversation output we can get from the result
+                try:
+                    if hasattr(result, "to_input_list"):
+                        partial_output = result.to_input_list()
+                        logger.debug(
+                            f"Extracted partial output during stream error: {len(partial_output) if partial_output else 0} messages"
+                        )
+                except Exception as extract_error:
+                    logger.debug(
+                        f"Failed to extract output during stream error: {extract_error}"
+                    )
+                    # Keep the existing partial_output
+
+                # Try to extract token usage from any available raw responses
+                try:
+                    if hasattr(result, "raw_responses") and result.raw_responses:
+                        total_input_tokens = 0
+                        total_output_tokens = 0
+                        total_tokens = 0
+                        for response in result.raw_responses:
+                            if hasattr(response, "usage") and response.usage:
+                                total_input_tokens += (
+                                    response.usage.input_tokens or 0
+                                )
+                                total_output_tokens += (
+                                    response.usage.output_tokens or 0
+                                )
+                                total_tokens += response.usage.total_tokens or 0
+
+                        partial_token_usage = {
+                            "input_tokens": total_input_tokens,
+                            "output_tokens": total_output_tokens,
+                            "total_tokens": total_tokens,
+                        }
+                        logger.debug(
+                            f"Extracted partial token usage during stream error: {partial_token_usage}"
+                        )
+
+                    # Try to extract turn count
+                    if hasattr(result, "current_turn"):
+                        partial_turn_count = max(result.current_turn - 1, 0)
+                        logger.debug(
+                            f"Extracted partial turn count during stream error: {partial_turn_count}"
+                        )
+                except Exception as usage_error:
+                    logger.debug(
+                        f"Failed to extract token usage during stream error: {usage_error}"
+                    )
+                    # Keep the existing partial values
+
+                # If this is a critical streaming error, we should fail the execution
+                # rather than continuing and potentially returning success=True
+                execution_time = time.time() - start_time
+                self._usage_stats["failed_executions"] += 1
+                self._usage_stats["total_execution_time"] += execution_time
+
+                # Update usage stats with any partial token usage we collected
+                if partial_token_usage:
+                    self._usage_stats["total_input_tokens"] += (
+                        partial_token_usage.get("input_tokens", 0)
+                    )
+                    self._usage_stats["total_output_tokens"] += (
+                        partial_token_usage.get("output_tokens", 0)
+                    )
+                    self._usage_stats["total_tokens"] += partial_token_usage.get(
+                        "total_tokens", 0
+                    )
+                    self._usage_stats["total_turns"] += partial_turn_count
+
+                return {
+                    "success": False,
+                    "output": partial_output if partial_output else [],
+                    "token_usage": partial_token_usage
+                    if partial_token_usage
+                    else {},
+                    "turn_count": partial_turn_count,
+                    "execution_time": execution_time,
+                    "error": str(stream_error),
+                }
+
+            # Debug: Log the result attributes
+            logger.debug(f"Result attributes: {dir(result)}")
+            logger.debug(f"Has raw_responses: {hasattr(result, 'raw_responses')}")
+            logger.debug(f"Has current_turn: {hasattr(result, 'current_turn')}")
+            if hasattr(result, "raw_responses"):
+                logger.debug(
+                    f"Raw responses count: {len(result.raw_responses) if result.raw_responses else 0}"
+                )
+
+            # Check if max_turns was exceeded
+            # The result object may have completed normally but hit the turn limit
+            if hasattr(result, "current_turn") and result.current_turn > 100:
+                max_turns_exceeded = True
+                logger.warning(f"| Max turns ({result.current_turn - 1}) reached")

             # Extract token usage from raw responses
             token_usage = {}
@@ -433,9 +626,25 @@ async def _execute_with_streaming(self, instruction: str) -> Dict[str, Any]:
                     "output_tokens": total_output_tokens,
                     "total_tokens": total_tokens,
                 }
+                # Update partial token usage as we go
+                partial_token_usage = token_usage
+            else:
+                # If raw_responses is empty, try to extract from individual responses
+                logger.debug(
+                    "No raw_responses found, checking for other response data"
+                )

             # Extract turn count
             turn_count = getattr(result, "current_turn", None)
+            if turn_count:
+                partial_turn_count = turn_count
+
+            # Try to extract partial conversation output in case of later failure
+            try:
+                partial_output = result.to_input_list()
+            except Exception as e:
+                logger.debug(f"Failed to extract conversation output: {e}")
+                # Keep whatever partial output we had before

             # Pretty usage block (prefixed lines)
             if token_usage:
@@ -450,16 +659,38 @@ async def _execute_with_streaming(self, instruction: str) -> Dict[str, Any]:
                     f"| Total: {total_tokens:,} | Input: {total_input_tokens:,} | Output: {total_output_tokens:,}",
                 ]
                 if turn_count is not None:
-                    lines.append("| ────────────────────────────────────────────────")
+                    lines.append(
+                        "| ────────────────────────────────────────────────"
+                    )
                     lines.append(f"| \033[1mTurns\033[0m: {turn_count}")
-                    lines.append("| ────────────────────────────────────────────────")
+                    lines.append(
+                        "| ────────────────────────────────────────────────"
+                    )
                 logger.info("\n".join(lines))

             # Extract conversation output
-            conversation_output = result.to_input_list()
+            conversation_output = []
+            try:
+                conversation_output = result.to_input_list()
+            except Exception as e:
+                logger.debug(f"Failed to extract final conversation output: {e}")
+                conversation_output = partial_output if partial_output else []
+
+            # Update partial results with final values
+            partial_output = conversation_output
+            partial_token_usage = (
+                token_usage if token_usage else partial_token_usage
+            )
+            partial_turn_count = turn_count if turn_count else partial_turn_count

             execution_time = time.time() - start_time

+            # Check if we hit max_turns limit and adjust turn count
+            if max_turns_exceeded and turn_count:
+                # When max_turns is exceeded, SDK reports the turn it tried to start
+                # but didn't execute, so subtract 1 for actual completed turns
+                turn_count = turn_count - 1
+
             # Update usage statistics
             self._usage_stats["total_input_tokens"] += token_usage.get(
                 "input_tokens", 0
@@ -470,6 +701,19 @@ async def _execute_with_streaming(self, instruction: str) -> Dict[str, Any]:
             self._usage_stats["total_tokens"] += token_usage.get("total_tokens", 0)
             self._usage_stats["total_turns"] += turn_count or 0
             self._usage_stats["total_execution_time"] += execution_time
+
+            # Check if we hit max_turns limit and should report as error
+            if max_turns_exceeded:
+                self._usage_stats["failed_executions"] += 1
+                return {
+                    "success": False,
+                    "output": conversation_output,
+                    "token_usage": token_usage if token_usage else {},
+                    "turn_count": turn_count if turn_count else 0,
+                    "execution_time": execution_time,
+                    "error": f"Max turns ({turn_count if turn_count else 0}) exceeded",
+                }
+
         self._usage_stats["successful_executions"] += 1

             return {
@@ -486,23 +730,43 @@ async def _execute_with_streaming(self, instruction: str) -> Dict[str, Any]:
             self._usage_stats["failed_executions"] += 1
             self._usage_stats["total_execution_time"] += execution_time

-            logger.error(f"| Agent execution failed: {e}", exc_info=True)
+            # Update usage stats with any partial token usage we collected
+            if partial_token_usage:
+                self._usage_stats["total_input_tokens"] += partial_token_usage.get(
+                    "input_tokens", 0
+                )
+                self._usage_stats["total_output_tokens"] += partial_token_usage.get(
+                    "output_tokens", 0
+                )
+                self._usage_stats["total_tokens"] += partial_token_usage.get(
+                    "total_tokens", 0
+                )
+                self._usage_stats["total_turns"] += partial_turn_count
+
+            error_msg = f"| Agent execution failed: {e}"
+            logger.error(error_msg, exc_info=True)
+            # Also log error to file (ensure proper line break)
+            self._write_to_log_file(tool_call_log_file, f"\n| ERROR: {error_msg}\n")

             return {
                 "success": False,
-                "output": "",
-                "token_usage": {},
-                "turn_count": 0,
+                "output": partial_output
+                if partial_output
+                else [],  # Preserve partial output
+                "token_usage": partial_token_usage if partial_token_usage else {},
+                "turn_count": partial_turn_count,
                 "execution_time": execution_time,
                 "error": str(e),
             }

-    async def execute(self, instruction: str) -> Dict[str, Any]:
+    async def execute(
+        self, instruction: str, tool_call_log_file: str = None
+    ) -> Dict[str, Any]:
         """
-        Execute instruction with automatic retries on transient errors.
+        Execute instruction without retries.

         Args:
             instruction: The instruction/prompt to execute
-            (Service configuration is taken from self.service_config)
+            tool_call_log_file: Optional path to log tool calls
+            (Service configuration is taken from self.service_config)

         Returns:
             Dictionary containing:
@@ -514,57 +778,28 @@
             - error: error message if failed
         """
-        for attempt in range(1, self.max_retries + 1):
-            # Merge default config with any overrides supplied at call time
-
-            result = await asyncio.wait_for(
-                self._execute_with_streaming(instruction), timeout=self.timeout
-            )
-
-            # Success - return immediately
-            if result["success"]:
-                return result
-
-            # Standardize error message
-            from src.errors import (
-                standardize_error_message,
-                is_retryable_error,
-                get_retry_delay,
-            )
-
-            error_msg = standardize_error_message(
-                result["error"] or "Unknown error", mcp_service=self.mcp_service
-            )
-            result["error"] = error_msg
-
-            if is_retryable_error(result["error"]) and attempt < self.max_retries:
-                wait_seconds = get_retry_delay(attempt)
-                logger.warning(
-                    f"| [Retry] Attempt {attempt}/{self.max_retries} failed. "
-                    f"| Waiting {wait_seconds}s before retrying: {error_msg}"
-                )
-                await asyncio.sleep(wait_seconds)
-                continue  # Retry
-
-            # Non-transient error or out of retry attempts - return last result
-            return result
+        result = await asyncio.wait_for(
+            self._execute_with_streaming(instruction, tool_call_log_file),
+            timeout=self.timeout,
+        )

-        # Should never reach here, but return the last result as fallback
         return result

-    def execute_sync(self, instruction: str) -> Dict[str, Any]:
+    def execute_sync(
+        self, instruction: str, tool_call_log_file: str = None
+    ) -> Dict[str, Any]:
         """
         Synchronous wrapper for execute method.

         Args:
             instruction: The instruction/prompt to execute
-            (Service configuration is taken from self.service_config)
+            tool_call_log_file: Optional path to log tool calls
+            (Service configuration is taken from self.service_config)

         Returns:
             Dictionary containing execution results
         """
         try:
-            return asyncio.run(self.execute(instruction))
+            return asyncio.run(self.execute(instruction, tool_call_log_file))
         except asyncio.TimeoutError:
             self._usage_stats["failed_executions"] += 1
             return {
diff --git a/src/base/task_manager.py b/src/base/task_manager.py
index 7ce8f44e..eacff0ad 100644
--- a/src/base/task_manager.py
+++ b/src/base/task_manager.py
@@ -140,7 +140,7 @@ def filter_tasks(self, task_filter: str) -> List[BaseTask]:
         if "/" in task_filter:
             try:
                 category, task_part = task_filter.split("/", 1)
-
+
                 # First try to match by task_id (could be numeric or string)
                 for task in all_tasks:
                     if task.category == category:
@@ -148,7 +148,10 @@ def filter_tasks(self, task_filter: str) -> List[BaseTask]:
                         if str(task.task_id) == task_part:
                             return [task]
                         # Also check if it's a task_N format and matches
-                        if task_part.startswith("task_") and str(task.task_id) == task_part.split("_", 1)[1]:
+                        if (
+                            task_part.startswith("task_")
+                            and str(task.task_id) == task_part.split("_", 1)[1]
+                        ):
                             return [task]
             except (ValueError, IndexError):
                 pass
@@ -276,6 +279,9 @@ def execute_task(self, task: BaseTask, agent_result: Dict[str, Any]) -> TaskResult:
                 error_message=str(e),
                 category=task.category,
                 task_id=task.task_id,
+                model_output=agent_result.get("output", ""),
+                token_usage=agent_result.get("token_usage", {}),
+                turn_count=agent_result.get("turn_count", 0),
             )

     def run_verification(self, task: BaseTask) -> subprocess.CompletedProcess:
diff --git a/src/evaluator.py b/src/evaluator.py
index ad27f20d..db1fa97b 100644
--- a/src/evaluator.py
+++ b/src/evaluator.py
@@ -178,7 +178,7 @@ def _run_single_task(self, task) -> TaskResult:
         """
         # Track overall task start time
         task_start_time = time.time()
-
+
         # Stage 1: Set up the initial state for the task
         setup_start_time = time.time()
         logger.info(
@@ -200,9 +200,7 @@ def _run_single_task(self, task) -> TaskResult:
                 task_execution_time=task_total_time,
             )
         display_time = self._format_duration(setup_time)
-        logger.info(
-            f"└─ Completed in {display_time}\n"
-        )
+        logger.info(f"└─ Completed in {display_time}\n")

         # Stage 2: Execute the task using the agent
         logger.info(
@@ -214,14 +212,19 @@ def _run_single_task(self, task) -> TaskResult:
         # Get task instruction from task manager
         task_instruction = self.task_manager.get_task_instruction(task)

+        # ---------- Prepare task_output_dir and tool call log file ----------
+        task_output_dir = self._get_task_output_dir(task)
+        task_output_dir.mkdir(parents=True, exist_ok=True)
+        execution_log_path = task_output_dir / "execution.log"
+
         # Execute with agent
-        agent_result = self.agent.execute_sync(task_instruction)
+        agent_result = self.agent.execute_sync(
+            task_instruction, str(execution_log_path)
+        )

         agent_execution_time = time.time() - execute_start_time

         # ---------- Write messages.json to task_output_dir ----------
-        task_output_dir = self._get_task_output_dir(task)
-        task_output_dir.mkdir(parents=True, exist_ok=True)
         messages_path = task_output_dir / "messages.json"
         self.results_reporter.save_messages_json(
             agent_result.get("output", []), messages_path
         )

         # Set service-specific environment variables for verification scripts
         self.state_manager.set_verification_environment(str(messages_path))

-        logger.info(
-            f"└─ Completed in {self._format_duration(agent_execution_time)}\n"
-        )
+        logger.info(f"└─ Completed in {self._format_duration(agent_execution_time)}\n")

         # Stage 3: Verify
         logger.info(
@@ -243,13 +244,11 @@ def _run_single_task(self, task) -> TaskResult:
         finally:
             # Clean up environment variables
             import os
+
             os.environ.pop("MCP_MESSAGES", None)
             os.environ.pop("MCP_GITHUB_TOKEN", None)
         verify_time = time.time() - verify_start_time
-        logger.info(
-            f"└─ Completed in {self._format_duration(verify_time)}\n"
-        )
-
+        logger.info(f"└─ Completed in {self._format_duration(verify_time)}\n")

         # Stage 4: Clean up
         logger.info(
@@ -258,13 +257,11 @@ def _run_single_task(self, task) -> TaskResult:
         cleanup_start_time = time.time()
         self.state_manager.clean_up(task)
         cleanup_time = time.time() - cleanup_start_time
-        logger.info(
-            f"└─ Completed in {self._format_duration(cleanup_time)}\n"
-        )
+        logger.info(f"└─ Completed in {self._format_duration(cleanup_time)}\n")

         # Calculate total task execution time
         task_total_time = time.time() - task_start_time
-
+
         # Add timing information to the result
         result.agent_execution_time = agent_execution_time
         result.task_execution_time = task_total_time
@@ -408,8 +405,6 @@ def _matches_filter(tr: TaskResult, flt: str) -> bool:
     logger.info(
         f"✓ Tasks passed: {aggregated_report.successful_tasks}/{aggregated_report.total_tasks} ({aggregated_report.success_rate:.1f}%)"
     )
-    logger.info(
-        f"⏱ Total time: {aggregated_report.total_task_execution_time:.1f}s"
-    )
+    logger.info(f"⏱ Total time: {aggregated_report.total_task_execution_time:.1f}s")

     return aggregated_report
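
Usage note (reviewer sketch, not part of the patch): after this change,
execute() and execute_sync() take an optional tool_call_log_file path, and a
failed run returns whatever partial output, token usage, and turn count were
collected before the error instead of empty values. The sketch below is a
minimal illustration of consuming that result dict; only the
execute_sync(instruction, tool_call_log_file) signature and the result keys
(success, output, token_usage, turn_count, execution_time, error) come from
this patch, while report_run and the commented call site are hypothetical
helpers for illustration.

    from typing import Any, Dict

    def report_run(result: Dict[str, Any]) -> None:
        """Summarize an MCPAgent execution result (illustrative sketch)."""
        if result["success"]:
            print(f"ok in {result['execution_time']:.1f}s, turns={result['turn_count']}")
        else:
            # Even on a streaming error the run now carries partial data.
            print("error:", result["error"])
            print("turns completed:", result["turn_count"])
            print("tokens:", result["token_usage"] or "none collected")
        # Full or partial conversation, mirrored on disk in execution.log
        for message in result["output"] or []:
            print(message)

    # Hypothetical call site; constructor arguments omitted (see src/agent.py):
    # result = agent.execute_sync(task_instruction, str(execution_log_path))
    # report_run(result)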