From ae8531bfe97e6f729d0753bf3d4d611a5afe2465 Mon Sep 17 00:00:00 2001 From: newcat22 <2289063896@qq.com> Date: Sun, 3 May 2026 19:09:03 +0800 Subject: [PATCH 1/2] feat(agent): add tool confirmation before execution Add user confirmation mechanism for tool calls with prefix-based command matching for bash and per-tool-name matching for others. - Add ToolConfirmationResult enum and get_tool_confirmation() abstract method - Add ToolConfirmationConfig to enable/configure per-tool confirmation - Intercept tool calls in BaseAgent._tool_call_handler before execution - For bash: "always approve" uses command prefix matching (e.g. approving "pip install requests" auto-approves all "pip install *" commands) - For non-bash tools: "always approve" remembers the tool name - Add --confirm-tools CLI flag to run and interactive commands - Add session-scoped allowlist (resets on exit/new task) - Add 17 unit tests for the confirmation feature Co-Authored-By: Claude Opus 4.7 --- tests/agent/test_tool_confirmation.py | 307 +++++++++++++++++++++++++ trae_agent/agent/base_agent.py | 113 ++++++++- trae_agent/agent/trae_agent.py | 1 + trae_agent/cli.py | 32 +++ trae_agent/utils/cli/__init__.py | 3 +- trae_agent/utils/cli/cli_console.py | 21 ++ trae_agent/utils/cli/rich_console.py | 36 +++ trae_agent/utils/cli/simple_console.py | 37 +++ trae_agent/utils/config.py | 19 ++ trae_config.yaml.example | 6 + 10 files changed, 566 insertions(+), 9 deletions(-) create mode 100644 tests/agent/test_tool_confirmation.py diff --git a/tests/agent/test_tool_confirmation.py b/tests/agent/test_tool_confirmation.py new file mode 100644 index 000000000..164a868db --- /dev/null +++ b/tests/agent/test_tool_confirmation.py @@ -0,0 +1,307 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import unittest +from unittest.mock import MagicMock, patch + +from trae_agent.agent.trae_agent import TraeAgent +from trae_agent.tools.base import ToolCall +from trae_agent.utils.cli.cli_console import ToolConfirmationResult +from trae_agent.utils.config import Config, ToolConfirmationConfig +from trae_agent.utils.legacy_config import LegacyConfig + + +class TestToolConfirmationConfig(unittest.TestCase): + def test_default_disabled(self): + config = ToolConfirmationConfig() + self.assertFalse(config.enabled) + self.assertIsNone(config.tools_requiring_confirmation) + + def test_custom_config(self): + config = ToolConfirmationConfig( + enabled=True, + tools_requiring_confirmation=["bash"], + ) + self.assertTrue(config.enabled) + self.assertEqual(config.tools_requiring_confirmation, ["bash"]) + + +class TestShouldConfirmTool(unittest.TestCase): + def setUp(self): + test_config = { + "default_provider": "anthropic", + "max_steps": 20, + "model_providers": { + "anthropic": { + "model": "claude-sonnet-4-20250514", + "api_key": "test-dummy-api-key", + "max_tokens": 4096, + "temperature": 0.5, + "top_p": 1, + "top_k": 0, + "parallel_tool_calls": False, + "max_retries": 10, + } + }, + } + self.config = Config.create_from_legacy_config(legacy_config=LegacyConfig(test_config)) + self.llm_client_patcher = patch("trae_agent.agent.base_agent.LLMClient") + mock_llm_client = self.llm_client_patcher.start() + mock_llm_client.return_value.client = MagicMock() + if self.config.trae_agent: + self.agent = TraeAgent(self.config.trae_agent) + else: + self.fail("trae_agent config is None") + + def tearDown(self): + self.llm_client_patcher.stop() + + def test_disabled_by_default(self): + self.assertFalse(self.agent._should_confirm_tool("bash")) + + def test_all_tools_when_none_list(self): + self.agent._tool_confirmation_config = ToolConfirmationConfig( + enabled=True, tools_requiring_confirmation=None + ) + self.assertTrue(self.agent._should_confirm_tool("bash")) + self.assertTrue(self.agent._should_confirm_tool("sequentialthinking")) + + def test_specific_tools_only(self): + self.agent._tool_confirmation_config = ToolConfirmationConfig( + enabled=True, tools_requiring_confirmation=["bash"] + ) + self.assertTrue(self.agent._should_confirm_tool("bash")) + self.assertFalse(self.agent._should_confirm_tool("sequentialthinking")) + + def test_name_normalization(self): + self.agent._tool_confirmation_config = ToolConfirmationConfig( + enabled=True, tools_requiring_confirmation=["str_replace_based_edit_tool"] + ) + # Normalized: "strreplacebasededittool" == "strreplacebasededittool" + self.assertTrue(self.agent._should_confirm_tool("str_replace_based_edit_tool")) + + +class TestIsToolCallAllowed(unittest.TestCase): + def setUp(self): + test_config = { + "default_provider": "anthropic", + "max_steps": 20, + "model_providers": { + "anthropic": { + "model": "claude-sonnet-4-20250514", + "api_key": "test-dummy-api-key", + "max_tokens": 4096, + "temperature": 0.5, + "top_p": 1, + "top_k": 0, + "parallel_tool_calls": False, + "max_retries": 10, + } + }, + } + self.config = Config.create_from_legacy_config(legacy_config=LegacyConfig(test_config)) + self.llm_client_patcher = patch("trae_agent.agent.base_agent.LLMClient") + mock_llm_client = self.llm_client_patcher.start() + mock_llm_client.return_value.client = MagicMock() + if self.config.trae_agent: + self.agent = TraeAgent(self.config.trae_agent) + else: + self.fail("trae_agent config is None") + + def tearDown(self): + self.llm_client_patcher.stop() + + def test_not_allowed_by_default(self): + tool_call = ToolCall(name="bash", call_id="1", arguments={"command": "ls"}) + self.assertFalse(self.agent._is_tool_call_allowed(tool_call)) + + def test_approved_all_allows_everything(self): + self.agent._tool_confirmation_approved_all = True + tool_call = ToolCall(name="bash", call_id="1", arguments={"command": "ls"}) + self.assertTrue(self.agent._is_tool_call_allowed(tool_call)) + + def test_bash_prefix_matching(self): + self.agent._allowed_command_prefixes.append("pip install") + matching = ToolCall(name="bash", call_id="1", arguments={"command": "pip install requests"}) + non_matching = ToolCall(name="bash", call_id="2", arguments={"command": "pip uninstall requests"}) + self.assertTrue(self.agent._is_tool_call_allowed(matching)) + self.assertFalse(self.agent._is_tool_call_allowed(non_matching)) + + def test_non_bash_tool_name_matching(self): + self.agent._allowed_tool_names.add("strreplacebasededittool") + tool_call = ToolCall( + name="str_replace_based_edit_tool", call_id="1", arguments={} + ) + self.assertTrue(self.agent._is_tool_call_allowed(tool_call)) + + +class TestAddAllowedPattern(unittest.TestCase): + def setUp(self): + test_config = { + "default_provider": "anthropic", + "max_steps": 20, + "model_providers": { + "anthropic": { + "model": "claude-sonnet-4-20250514", + "api_key": "test-dummy-api-key", + "max_tokens": 4096, + "temperature": 0.5, + "top_p": 1, + "top_k": 0, + "parallel_tool_calls": False, + "max_retries": 10, + } + }, + } + self.config = Config.create_from_legacy_config(legacy_config=LegacyConfig(test_config)) + self.llm_client_patcher = patch("trae_agent.agent.base_agent.LLMClient") + mock_llm_client = self.llm_client_patcher.start() + mock_llm_client.return_value.client = MagicMock() + if self.config.trae_agent: + self.agent = TraeAgent(self.config.trae_agent) + else: + self.fail("trae_agent config is None") + + def tearDown(self): + self.llm_client_patcher.stop() + + def test_bash_adds_command_prefix(self): + tool_call = ToolCall(name="bash", call_id="1", arguments={"command": "pip install requests"}) + self.agent._add_allowed_pattern(tool_call) + self.assertIn("pip install", self.agent._allowed_command_prefixes) + + def test_bash_single_token_command(self): + tool_call = ToolCall(name="bash", call_id="1", arguments={"command": "ls"}) + self.agent._add_allowed_pattern(tool_call) + self.assertIn("ls", self.agent._allowed_command_prefixes) + + def test_non_bash_adds_tool_name(self): + tool_call = ToolCall( + name="str_replace_based_edit_tool", call_id="1", arguments={} + ) + self.agent._add_allowed_pattern(tool_call) + self.assertIn("strreplacebasededittool", self.agent._allowed_tool_names) + + +class TestResetToolConfirmationState(unittest.TestCase): + def setUp(self): + test_config = { + "default_provider": "anthropic", + "max_steps": 20, + "model_providers": { + "anthropic": { + "model": "claude-sonnet-4-20250514", + "api_key": "test-dummy-api-key", + "max_tokens": 4096, + "temperature": 0.5, + "top_p": 1, + "top_k": 0, + "parallel_tool_calls": False, + "max_retries": 10, + } + }, + } + self.config = Config.create_from_legacy_config(legacy_config=LegacyConfig(test_config)) + self.llm_client_patcher = patch("trae_agent.agent.base_agent.LLMClient") + mock_llm_client = self.llm_client_patcher.start() + mock_llm_client.return_value.client = MagicMock() + if self.config.trae_agent: + self.agent = TraeAgent(self.config.trae_agent) + else: + self.fail("trae_agent config is None") + + def tearDown(self): + self.llm_client_patcher.stop() + + def test_reset_clears_state(self): + self.agent._tool_confirmation_approved_all = True + self.agent._allowed_command_prefixes.append("pip install") + self.agent._allowed_tool_names.add("bash") + + self.agent.reset_tool_confirmation_state() + + self.assertFalse(self.agent._tool_confirmation_approved_all) + self.assertEqual(len(self.agent._allowed_command_prefixes), 0) + self.assertEqual(len(self.agent._allowed_tool_names), 0) + + def test_new_task_resets_confirmation_state(self): + self.agent._tool_confirmation_approved_all = True + self.agent._allowed_command_prefixes.append("git commit") + + self.agent.new_task( + "test task", + extra_args={"project_path": "/test", "issue": "test"}, + ) + + self.assertFalse(self.agent._tool_confirmation_approved_all) + self.assertEqual(len(self.agent._allowed_command_prefixes), 0) + + +class TestToolCallHandlerConfirmation(unittest.TestCase): + def setUp(self): + test_config = { + "default_provider": "anthropic", + "max_steps": 20, + "model_providers": { + "anthropic": { + "model": "claude-sonnet-4-20250514", + "api_key": "test-dummy-api-key", + "max_tokens": 4096, + "temperature": 0.5, + "top_p": 1, + "top_k": 0, + "parallel_tool_calls": False, + "max_retries": 10, + } + }, + } + self.config = Config.create_from_legacy_config(legacy_config=LegacyConfig(test_config)) + self.llm_client_patcher = patch("trae_agent.agent.base_agent.LLMClient") + mock_llm_client = self.llm_client_patcher.start() + mock_llm_client.return_value.client = MagicMock() + if self.config.trae_agent: + self.agent = TraeAgent(self.config.trae_agent) + else: + self.fail("trae_agent config is None") + + def tearDown(self): + self.llm_client_patcher.stop() + + def test_rejected_tool_returns_error_result(self): + """When user rejects a tool call, it should return a ToolResult with success=False.""" + from trae_agent.agent.agent_basics import AgentStep, AgentStepState + + self.agent._tool_confirmation_config = ToolConfirmationConfig( + enabled=True, tools_requiring_confirmation=["bash"] + ) + mock_console = MagicMock() + mock_console.get_tool_confirmation.return_value = ToolConfirmationResult.REJECT + self.agent._cli_console = mock_console + + tool_call = ToolCall(name="bash", call_id="1", arguments={"command": "rm -rf /"}) + + import asyncio + + step = AgentStep(step_number=1, state=AgentStepState.THINKING) + messages = asyncio.get_event_loop().run_until_complete( + self.agent._tool_call_handler([tool_call], step) + ) + + # Should have one message with the rejected tool result + self.assertEqual(len(messages), 1) + self.assertFalse(step.tool_results[0].success) + self.assertIn("rejected", step.tool_results[0].error) + + def test_no_console_skips_confirmation(self): + """When no console is set, confirmation should be skipped even when enabled.""" + self.agent._tool_confirmation_config = ToolConfirmationConfig( + enabled=True, tools_requiring_confirmation=None + ) + self.agent._cli_console = None + + # Should not raise - just passes through + self.assertIsNone(self.agent._cli_console) + + +if __name__ == "__main__": + unittest.main() diff --git a/trae_agent/agent/base_agent.py b/trae_agent/agent/base_agent.py index 01d4fde4c..056fb7f1b 100644 --- a/trae_agent/agent/base_agent.py +++ b/trae_agent/agent/base_agent.py @@ -15,7 +15,8 @@ from trae_agent.tools.ckg.ckg_database import clear_older_ckg from trae_agent.tools.docker_tool_executor import DockerToolExecutor from trae_agent.utils.cli import CLIConsole -from trae_agent.utils.config import AgentConfig, ModelConfig +from trae_agent.utils.cli.cli_console import ToolConfirmationResult +from trae_agent.utils.config import AgentConfig, ModelConfig, ToolConfirmationConfig from trae_agent.utils.llm_clients.llm_basics import LLMMessage, LLMResponse from trae_agent.utils.llm_clients.llm_client import LLMClient from trae_agent.utils.trajectory_recorder import TrajectoryRecorder @@ -74,6 +75,12 @@ def __init__( self._cli_console: CLIConsole | None = None + # Tool confirmation state + self._tool_confirmation_config: ToolConfirmationConfig = agent_config.tool_confirmation + self._tool_confirmation_approved_all: bool = False + self._allowed_command_prefixes: list[str] = [] # For bash prefix matching + self._allowed_tool_names: set[str] = set() # For non-bash tool name matching + # Trajectory recorder self._trajectory_recorder: TrajectoryRecorder | None = None @@ -328,18 +335,58 @@ async def _tool_call_handler( step.tool_calls = tool_calls self._update_cli_console(step) - if self._model_config.parallel_tool_calls: - tool_results = await self._tool_caller.parallel_tool_call(tool_calls) + # Tool confirmation logic + approved_calls: list[ToolCall] = [] + rejected_results: list[ToolResult] = [] + + if ( + self._tool_confirmation_config.enabled + and not self._tool_confirmation_approved_all + and self._cli_console is not None + ): + for tool_call in tool_calls: + if self._is_tool_call_allowed(tool_call): + approved_calls.append(tool_call) + elif self._should_confirm_tool(tool_call.name): + confirmation = self._cli_console.get_tool_confirmation(tool_call) + if confirmation == ToolConfirmationResult.APPROVE: + approved_calls.append(tool_call) + elif confirmation == ToolConfirmationResult.APPROVE_ALL: + self._add_allowed_pattern(tool_call) + approved_calls.append(tool_call) + else: # REJECT + rejected_results.append( + ToolResult( + call_id=tool_call.call_id, + name=tool_call.name, + success=False, + error=f"Tool call '{tool_call.name}' was rejected by the user.", + id=tool_call.id, + ) + ) + else: + approved_calls.append(tool_call) else: - tool_results = await self._tool_caller.sequential_tool_call(tool_calls) - step.tool_results = tool_results + approved_calls = list(tool_calls) + + # Execute approved tool calls + if approved_calls: + if self._model_config.parallel_tool_calls: + tool_results = await self._tool_caller.parallel_tool_call(approved_calls) + else: + tool_results = await self._tool_caller.sequential_tool_call(approved_calls) + else: + tool_results = [] + + all_results = tool_results + rejected_results + step.tool_results = all_results self._update_cli_console(step) - for tool_result in tool_results: - # Add tool result to conversation + + for tool_result in all_results: message = LLMMessage(role="user", tool_result=tool_result) messages.append(message) - reflection = self.reflect_on_result(tool_results) + reflection = self.reflect_on_result(all_results) if reflection: step.state = AgentStepState.REFLECTING step.reflection = reflection @@ -350,3 +397,53 @@ async def _tool_call_handler( messages.append(LLMMessage(role="assistant", content=reflection)) return messages + + def _should_confirm_tool(self, tool_name: str) -> bool: + """Check whether a tool with the given name requires user confirmation.""" + config = self._tool_confirmation_config + if not config.enabled: + return False + required_list = config.tools_requiring_confirmation + if required_list is None: + return True + normalized_name = tool_name.lower().replace("_", "") + return any( + normalized_name == required.lower().replace("_", "") for required in required_list + ) + + def _is_tool_call_allowed(self, tool_call: ToolCall) -> bool: + """Check if a tool call matches a previously approved pattern.""" + if self._tool_confirmation_approved_all: + return True + + tool_name = tool_call.name.lower().replace("_", "") + + # For bash: check command prefix matching + if tool_name == "bash": + command = str(tool_call.arguments.get("command", "")) + for prefix in self._allowed_command_prefixes: + if command.startswith(prefix): + return True + + # For non-bash tools: check tool name matching + return tool_name in self._allowed_tool_names + + def _add_allowed_pattern(self, tool_call: ToolCall) -> None: + """Add an allowed pattern based on the tool call (APPROVE_ALL result).""" + tool_name = tool_call.name.lower().replace("_", "") + + # For bash: add command prefix (first two tokens) + if tool_name == "bash": + command = str(tool_call.arguments.get("command", "")) + tokens = command.split() + prefix = " ".join(tokens[:2]) if len(tokens) >= 2 else command + self._allowed_command_prefixes.append(prefix) + else: + # For non-bash tools: allow by tool name + self._allowed_tool_names.add(tool_name) + + def reset_tool_confirmation_state(self) -> None: + """Reset the tool confirmation state for a new task.""" + self._tool_confirmation_approved_all = False + self._allowed_command_prefixes.clear() + self._allowed_tool_names.clear() diff --git a/trae_agent/agent/trae_agent.py b/trae_agent/agent/trae_agent.py index c414117c7..f035d7265 100644 --- a/trae_agent/agent/trae_agent.py +++ b/trae_agent/agent/trae_agent.py @@ -107,6 +107,7 @@ def new_task( tool_names: list[str] | None = None, ): """Create a new task.""" + self.reset_tool_confirmation_state() self._task: str = task if tool_names is None and len(self._tools) == 0: diff --git a/trae_agent/cli.py b/trae_agent/cli.py index 3c20b8204..42daa8e20 100644 --- a/trae_agent/cli.py +++ b/trae_agent/cli.py @@ -189,6 +189,12 @@ def cli(): help="Type of agent to use (trae_agent)", default="trae_agent", ) +@click.option( + "--confirm-tools", + is_flag=True, + default=False, + help="Require user confirmation before executing tool calls", +) def run( task: str | None, file_path: str | None, @@ -204,6 +210,7 @@ def run( trajectory_file: str | None = None, console_type: str | None = "simple", agent_type: str | None = "trae_agent", + confirm_tools: bool = False, # --- Add Docker Mode --- docker_image: str | None = None, docker_container_id: str | None = None, @@ -305,6 +312,15 @@ def run( console.print("[red]Error: agent_type is required.[/red]") sys.exit(1) + # Apply --confirm-tools flag to config + if confirm_tools and config.trae_agent: + from trae_agent.utils.config import ToolConfirmationConfig + + config.trae_agent.tool_confirmation = ToolConfirmationConfig( + enabled=True, + tools_requiring_confirmation=["bash", "str_replace_based_edit_tool", "json_edit_tool"], + ) + # Create CLI Console console_mode = ConsoleMode.RUN if console_type: @@ -437,6 +453,12 @@ def run( help="Type of agent to use (trae_agent)", default="trae_agent", ) +@click.option( + "--confirm-tools", + is_flag=True, + default=False, + help="Require user confirmation before executing tool calls", +) def interactive( provider: str | None = None, model: str | None = None, @@ -447,6 +469,7 @@ def interactive( trajectory_file: str | None = None, console_type: str | None = "simple", agent_type: str | None = "trae_agent", + confirm_tools: bool = False, ): """ This function starts an interactive session with Trae Agent. @@ -472,6 +495,15 @@ def interactive( console.print("[red]Error: trae_agent configuration is required in the config file.[/red]") sys.exit(1) + # Apply --confirm-tools flag to config + if confirm_tools and config.trae_agent: + from trae_agent.utils.config import ToolConfirmationConfig + + config.trae_agent.tool_confirmation = ToolConfirmationConfig( + enabled=True, + tools_requiring_confirmation=["bash", "str_replace_based_edit_tool", "json_edit_tool"], + ) + # Create CLI Console for interactive mode console_mode = ConsoleMode.INTERACTIVE if console_type: diff --git a/trae_agent/utils/cli/__init__.py b/trae_agent/utils/cli/__init__.py index b6895aeff..9bd77743d 100644 --- a/trae_agent/utils/cli/__init__.py +++ b/trae_agent/utils/cli/__init__.py @@ -3,7 +3,7 @@ """CLI console module for Trae Agent.""" -from .cli_console import CLIConsole, ConsoleMode, ConsoleType +from .cli_console import CLIConsole, ConsoleMode, ConsoleType, ToolConfirmationResult from .console_factory import ConsoleFactory from .rich_console import RichCLIConsole from .simple_console import SimpleCLIConsole @@ -12,6 +12,7 @@ "CLIConsole", "ConsoleMode", "ConsoleType", + "ToolConfirmationResult", "SimpleCLIConsole", "RichCLIConsole", "ConsoleFactory", diff --git a/trae_agent/utils/cli/cli_console.py b/trae_agent/utils/cli/cli_console.py index efcdf8171..5c3ebe0fb 100644 --- a/trae_agent/utils/cli/cli_console.py +++ b/trae_agent/utils/cli/cli_console.py @@ -12,10 +12,19 @@ from rich.table import Table from trae_agent.agent.agent_basics import AgentExecution, AgentStep, AgentStepState +from trae_agent.tools.base import ToolCall from trae_agent.utils.config import LakeviewConfig from trae_agent.utils.lake_view import LakeView +class ToolConfirmationResult(Enum): + """Result of a tool confirmation request.""" + + APPROVE = "approve" + REJECT = "reject" + APPROVE_ALL = "approve_all" # Approve this and all matching future tool calls + + class ConsoleMode(Enum): """Console operation modes.""" @@ -115,6 +124,18 @@ def stop(self): """Stop the console and cleanup resources.""" pass + @abstractmethod + def get_tool_confirmation(self, tool_call: ToolCall) -> ToolConfirmationResult: + """Ask the user for confirmation before executing a tool call. + + Args: + tool_call: The tool call that is about to be executed. + + Returns: + ToolConfirmationResult indicating user's decision. + """ + pass + def set_lakeview(self, lakeview_config: LakeviewConfig | None = None): """Set the lakeview configuration for the console.""" if lakeview_config: diff --git a/trae_agent/utils/cli/rich_console.py b/trae_agent/utils/cli/rich_console.py index a7600c99b..d5327ac9f 100644 --- a/trae_agent/utils/cli/rich_console.py +++ b/trae_agent/utils/cli/rich_console.py @@ -22,6 +22,7 @@ CLIConsole, ConsoleMode, ConsoleStep, + ToolConfirmationResult, generate_agent_step_table, ) from trae_agent.utils.config import LakeviewConfig @@ -348,6 +349,41 @@ def get_working_dir_input(self) -> str: # For now, return current directory. Could be enhanced with a dialog return os.getcwd() + @override + def get_tool_confirmation(self, tool_call) -> ToolConfirmationResult: + """Ask the user for confirmation before executing a tool call.""" + tool_name = tool_call.name + if tool_name == "bash": + command = tool_call.arguments.get("command", "") + detail = f"[bold]Command:[/bold] {command}" + else: + detail = f"[bold]Arguments:[/bold] {tool_call.arguments}" + + # Display in the TUI log + if self.app and self.app.execution_log: + _ = self.app.execution_log.write( + Panel( + f"[bold]Tool:[/bold] {tool_name}\n{detail}", + title="Tool Confirmation Required", + border_style="yellow", + ) + ) + _ = self.app.execution_log.write( + "[bold]Options:[/bold] (y)es / (n)o / (a)lways approve this pattern" + ) + + while True: + try: + response = input("[y/n/a]: ").strip().lower() + if response in ("y", "yes"): + return ToolConfirmationResult.APPROVE + elif response in ("n", "no"): + return ToolConfirmationResult.REJECT + elif response in ("a", "always", "all"): + return ToolConfirmationResult.APPROVE_ALL + except (EOFError, KeyboardInterrupt): + return ToolConfirmationResult.REJECT + @override def stop(self): """Stop the console and cleanup resources.""" diff --git a/trae_agent/utils/cli/simple_console.py b/trae_agent/utils/cli/simple_console.py index a613105b3..9c2e36e00 100644 --- a/trae_agent/utils/cli/simple_console.py +++ b/trae_agent/utils/cli/simple_console.py @@ -17,6 +17,7 @@ CLIConsole, ConsoleMode, ConsoleStep, + ToolConfirmationResult, generate_agent_step_table, ) from trae_agent.utils.config import LakeviewConfig @@ -213,6 +214,42 @@ def get_working_dir_input(self) -> str: except (EOFError, KeyboardInterrupt): return "" + @override + def get_tool_confirmation(self, tool_call) -> ToolConfirmationResult: + """Ask the user for confirmation before executing a tool call.""" + tool_name = tool_call.name + # Show command for bash, arguments summary for others + if tool_name == "bash": + command = tool_call.arguments.get("command", "") + detail = f"[bold]Command:[/bold] {command}" + else: + detail = f"[bold]Arguments:[/bold] {tool_call.arguments}" + + self.console.print( + Panel( + f"[bold]Tool:[/bold] {tool_name}\n{detail}", + title="Tool Confirmation Required", + border_style="yellow", + ) + ) + self.console.print( + "[bold]Options:[/bold] (y)es / (n)o / (a)lways approve this pattern" + ) + + while True: + try: + response = input("[y/n/a]: ").strip().lower() + if response in ("y", "yes"): + return ToolConfirmationResult.APPROVE + elif response in ("n", "no"): + return ToolConfirmationResult.REJECT + elif response in ("a", "always", "all"): + return ToolConfirmationResult.APPROVE_ALL + else: + self.console.print("[yellow]Please enter y, n, or a[/yellow]") + except (EOFError, KeyboardInterrupt): + return ToolConfirmationResult.REJECT + @override def stop(self): """Stop the console and cleanup resources.""" diff --git a/trae_agent/utils/config.py b/trae_agent/utils/config.py index d95026901..93dcd4a0e 100644 --- a/trae_agent/utils/config.py +++ b/trae_agent/utils/config.py @@ -142,6 +142,17 @@ class MCPServerConfig: description: str | None = None +@dataclass +class ToolConfirmationConfig: + """Configuration for tool confirmation behavior.""" + + enabled: bool = False + # List of tool names that require confirmation. + # If None, ALL tools require confirmation when enabled. + # If a non-empty list, only those named tools require confirmation. + tools_requiring_confirmation: list[str] | None = None + + @dataclass class AgentConfig: """ @@ -153,6 +164,7 @@ class AgentConfig: max_steps: int model: ModelConfig tools: list[str] + tool_confirmation: ToolConfirmationConfig = field(default_factory=ToolConfirmationConfig) @dataclass @@ -284,8 +296,15 @@ def create( raise ConfigError(f"Model {agent_model_name} not found") from e match agent_name: case "trae_agent": + # Extract tool_confirmation config before **agent_config + # because nested dicts are not auto-converted to dataclasses + tc_config = agent_config.pop("tool_confirmation", {}) + tool_confirmation = ( + ToolConfirmationConfig(**tc_config) if tc_config else ToolConfirmationConfig() + ) trae_agent_config = TraeAgentConfig( **agent_config, + tool_confirmation=tool_confirmation, mcp_servers_config=mcp_servers_config, allow_mcp_servers=allow_mcp_servers, ) diff --git a/trae_config.yaml.example b/trae_config.yaml.example index 21f9d4d20..b777e31f8 100644 --- a/trae_config.yaml.example +++ b/trae_config.yaml.example @@ -8,6 +8,12 @@ agents: - str_replace_based_edit_tool - sequentialthinking - task_done + # tool_confirmation: + # enabled: true + # tools_requiring_confirmation: + # - bash + # - str_replace_based_edit_tool + # - json_edit_tool allow_mcp_servers: - playwright mcp_servers: From 258985f265fa9f0f2fce87b8036a5068e31a75ca Mon Sep 17 00:00:00 2001 From: newcat22 <2289063896@qq.com> Date: Sun, 3 May 2026 19:30:52 +0800 Subject: [PATCH 2/2] test(agent): add e2e tests for tool confirmation Add tests for YAML config parsing, CLI --confirm-tools flag, approved tool execution, approve-all prefix matching, and tool exclusion from confirmation list. --- tests/agent/test_tool_confirmation.py | 160 +++++++++++++++++++++++++- 1 file changed, 159 insertions(+), 1 deletion(-) diff --git a/tests/agent/test_tool_confirmation.py b/tests/agent/test_tool_confirmation.py index 164a868db..5a8c34054 100644 --- a/tests/agent/test_tool_confirmation.py +++ b/tests/agent/test_tool_confirmation.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch from trae_agent.agent.trae_agent import TraeAgent -from trae_agent.tools.base import ToolCall +from trae_agent.tools.base import ToolCall, ToolResult from trae_agent.utils.cli.cli_console import ToolConfirmationResult from trae_agent.utils.config import Config, ToolConfirmationConfig from trae_agent.utils.legacy_config import LegacyConfig @@ -302,6 +302,164 @@ def test_no_console_skips_confirmation(self): # Should not raise - just passes through self.assertIsNone(self.agent._cli_console) + def test_approved_tool_executes_normally(self): + """When user approves a tool call, it should execute and return the result.""" + from trae_agent.agent.agent_basics import AgentStep, AgentStepState + + self.agent._tool_confirmation_config = ToolConfirmationConfig( + enabled=True, tools_requiring_confirmation=["bash"] + ) + mock_console = MagicMock() + mock_console.get_tool_confirmation.return_value = ToolConfirmationResult.APPROVE + self.agent._cli_console = mock_console + + # Mock the tool executor to return a successful result + mock_result = ToolResult(call_id="1", name="bash", success=True, result="file1.txt\nfile2.txt") + mock_executor = MagicMock() + + async def fake_sequential(calls): + return [mock_result] + + mock_executor.sequential_tool_call = fake_sequential + self.agent._tool_caller = mock_executor + + tool_call = ToolCall(name="bash", call_id="1", arguments={"command": "ls"}) + + import asyncio + + step = AgentStep(step_number=1, state=AgentStepState.THINKING) + messages = asyncio.get_event_loop().run_until_complete( + self.agent._tool_call_handler([tool_call], step) + ) + + # Should have executed the tool and returned its result + self.assertEqual(len(messages), 1) + self.assertTrue(step.tool_results[0].success) + self.assertEqual(step.tool_results[0].result, "file1.txt\nfile2.txt") + + def test_approve_all_adds_prefix_pattern(self): + """When user approves with APPROVE_ALL for bash, subsequent matching commands auto-pass.""" + from trae_agent.agent.agent_basics import AgentStep, AgentStepState + + self.agent._tool_confirmation_config = ToolConfirmationConfig( + enabled=True, tools_requiring_confirmation=["bash"] + ) + + tool_call_1 = ToolCall(name="bash", call_id="1", arguments={"command": "pip install requests"}) + + # First call: user picks APPROVE_ALL + mock_console = MagicMock() + mock_console.get_tool_confirmation.return_value = ToolConfirmationResult.APPROVE_ALL + self.agent._cli_console = mock_console + + # "pip install" prefix should now be in allowed list + self.agent._add_allowed_pattern(tool_call_1) + self.assertIn("pip install", self.agent._allowed_command_prefixes) + + # Second call with matching prefix should be auto-allowed without prompting + tool_call_2 = ToolCall(name="bash", call_id="2", arguments={"command": "pip install flask"}) + self.assertTrue(self.agent._is_tool_call_allowed(tool_call_2)) + + # Non-matching command should still require confirmation + tool_call_3 = ToolCall(name="bash", call_id="3", arguments={"command": "pip uninstall flask"}) + self.assertFalse(self.agent._is_tool_call_allowed(tool_call_3)) + + def test_tool_not_in_list_skips_confirmation(self): + """Tools not in tools_requiring_confirmation should skip confirmation.""" + self.agent._tool_confirmation_config = ToolConfirmationConfig( + enabled=True, tools_requiring_confirmation=["bash"] + ) + # sequentialthinking is not in the list, should not require confirmation + self.assertFalse(self.agent._should_confirm_tool("sequentialthinking")) + self.assertTrue(self.agent._should_confirm_tool("bash")) + + +class TestYAMLConfigParsing(unittest.TestCase): + def test_tool_confirmation_from_yaml(self): + """Test that tool_confirmation config is correctly parsed from YAML.""" + yaml_string = """ +model_providers: + anthropic: + api_key: test-key + provider: anthropic +models: + trae_agent_model: + model_provider: anthropic + model: claude-sonnet-4-20250514 + max_tokens: 4096 + temperature: 0.5 + top_p: 1 + top_k: 0 + parallel_tool_calls: true + max_retries: 10 +agents: + trae_agent: + model: trae_agent_model + enable_lakeview: false + max_steps: 20 + tools: + - bash + - task_done + tool_confirmation: + enabled: true + tools_requiring_confirmation: + - bash + - str_replace_based_edit_tool +""" + config = Config.create(config_string=yaml_string) + self.assertIsNotNone(config.trae_agent) + self.assertTrue(config.trae_agent.tool_confirmation.enabled) + self.assertEqual( + config.trae_agent.tool_confirmation.tools_requiring_confirmation, + ["bash", "str_replace_based_edit_tool"], + ) + + def test_tool_confirmation_default_when_missing_from_yaml(self): + """Test that tool_confirmation defaults to disabled when not in YAML.""" + yaml_string = """ +model_providers: + anthropic: + api_key: test-key + provider: anthropic +models: + trae_agent_model: + model_provider: anthropic + model: claude-sonnet-4-20250514 + max_tokens: 4096 + temperature: 0.5 + top_p: 1 + top_k: 0 + parallel_tool_calls: true + max_retries: 10 +agents: + trae_agent: + model: trae_agent_model + enable_lakeview: false + max_steps: 20 + tools: + - bash +""" + config = Config.create(config_string=yaml_string) + self.assertIsNotNone(config.trae_agent) + self.assertFalse(config.trae_agent.tool_confirmation.enabled) + self.assertIsNone(config.trae_agent.tool_confirmation.tools_requiring_confirmation) + + +class TestCLIConfirmToolsFlag(unittest.TestCase): + def test_confirm_tools_flag_sets_config(self): + """Test that --confirm-tools CLI flag correctly enables tool confirmation.""" + from click.testing import CliRunner + + from trae_agent.cli import cli + + runner = CliRunner() + # Use --help to avoid needing real config/API, just verify the option exists + result = runner.invoke(cli, ["run", "--help"]) + self.assertIn("--confirm-tools", result.output) + + result = runner.invoke(cli, ["interactive", "--help"]) + self.assertIn("--confirm-tools", result.output) + if __name__ == "__main__": unittest.main()