diff --git a/tests/tools/test_edit_tool.py b/tests/tools/test_edit_tool.py index cf948df47..aeb14fbc2 100644 --- a/tests/tools/test_edit_tool.py +++ b/tests/tools/test_edit_tool.py @@ -3,7 +3,8 @@ import unittest from pathlib import Path -from unittest.mock import AsyncMock, patch +from tempfile import TemporaryDirectory +from unittest.mock import patch from trae_agent.tools.base import ToolCallArguments from trae_agent.tools.edit_tool import TextEditorTool @@ -102,13 +103,36 @@ async def test_str_replace_success(self): async def test_view_directory(self): self.mock_file_system(exists=True, is_dir=True) - with patch("trae_agent.tools.edit_tool.run", new_callable=AsyncMock) as mock_run: - mock_run.return_value = (0, "file1\nfile2", "") + with patch.object( + self.tool, + "_iter_visible_paths", + return_value=[self.test_dir, self.test_dir / "file1", self.test_dir / "file2"], + ): result = await self.tool.execute( ToolCallArguments({"command": "view", "path": str(self.test_dir)}) ) self.assertIn("files and directories", result.output) + async def test_view_directory_handles_shell_metacharacters(self): + with TemporaryDirectory() as temp_dir: + root = Path(temp_dir) / "workspace; touch injected" + root.mkdir() + (root / "file.txt").write_text("content") + (root / ".hidden").write_text("hidden") + child_dir = root / "child" + child_dir.mkdir() + (child_dir / "nested.txt").write_text("nested") + + result = await self.tool.execute( + ToolCallArguments({"command": "view", "path": str(root)}) + ) + + self.assertEqual(result.error_code, 0) + self.assertIn(str(root / "file.txt"), result.output) + self.assertIn(str(child_dir / "nested.txt"), result.output) + self.assertNotIn(".hidden", result.output) + self.assertFalse((Path(temp_dir) / "injected").exists()) + async def test_view_file(self): self.mock_file_system(exists=True, is_dir=False, content="line1\nline2\nline3") result = await self.tool.execute( diff --git a/trae_agent/tools/edit_tool.py b/trae_agent/tools/edit_tool.py index 3185b574d..0f6c3f68a 100644 --- a/trae_agent/tools/edit_tool.py +++ b/trae_agent/tools/edit_tool.py @@ -13,7 +13,7 @@ from typing import override from trae_agent.tools.base import Tool, ToolCallArguments, ToolError, ToolExecResult, ToolParameter -from trae_agent.tools.run import maybe_truncate, run +from trae_agent.tools.run import maybe_truncate EditToolSubCommands = [ "view", @@ -159,10 +159,16 @@ async def _view(self, path: Path, view_range: list[int] | None = None) -> ToolEx "The `view_range` parameter is not allowed when `path` points to a directory." ) - return_code, stdout, stderr = await run(rf"find {path} -maxdepth 2 -not -path '*/\.*'") - if not stderr: - stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n" - return ToolExecResult(error_code=return_code, output=stdout, error=stderr) + try: + stdout = "\n".join(str(file_path) for file_path in self._iter_visible_paths(path)) + except OSError as exc: + return ToolExecResult(error_code=1, error=str(exc)) + + stdout = ( + f"Here's the files and directories up to 2 levels deep in {path}, " + f"excluding hidden items:\n{stdout}\n" + ) + return ToolExecResult(error_code=0, output=stdout) file_content = self.read_file(path) init_line = 1 @@ -307,6 +313,20 @@ def _make_output( f"Here's the result of running `cat -n` on {file_descriptor}:\n" + file_content + "\n" ) + @staticmethod + def _iter_visible_paths(path: Path): + yield path + for child in path.iterdir(): + if child.name.startswith("."): + continue + yield child + if not child.is_dir(): + continue + for grandchild in child.iterdir(): + if grandchild.name.startswith("."): + continue + yield grandchild + async def _view_handler(self, arguments: ToolCallArguments, _path: Path) -> ToolExecResult: view_range = arguments.get("view_range", None) if view_range is None: diff --git a/trae_agent/tools/edit_tool_cli.py b/trae_agent/tools/edit_tool_cli.py index 78ebae50a..28d5a88b3 100644 --- a/trae_agent/tools/edit_tool_cli.py +++ b/trae_agent/tools/edit_tool_cli.py @@ -203,10 +203,16 @@ async def _view(self, path: Path, view_range: list[int] | None = None) -> ToolEx "The `view_range` parameter is not allowed when `path` points to a directory." ) - return_code, stdout, stderr = await run(rf"find {path} -maxdepth 2 -not -path '*/\.*'") - if not stderr: - stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n" - return ToolExecResult(error_code=return_code, output=stdout, error=stderr) + try: + stdout = "\n".join(str(file_path) for file_path in self._iter_visible_paths(path)) + except OSError as exc: + return ToolExecResult(error_code=1, error=str(exc)) + + stdout = ( + f"Here's the files and directories up to 2 levels deep in {path}, " + f"excluding hidden items:\n{stdout}\n" + ) + return ToolExecResult(error_code=0, output=stdout) file_content = self.read_file(path) init_line = 1 @@ -351,6 +357,20 @@ def _make_output( f"Here's the result of running `cat -n` on {file_descriptor}:\n" + file_content + "\n" ) + @staticmethod + def _iter_visible_paths(path: Path): + yield path + for child in path.iterdir(): + if child.name.startswith("."): + continue + yield child + if not child.is_dir(): + continue + for grandchild in child.iterdir(): + if grandchild.name.startswith("."): + continue + yield grandchild + async def _view_handler(self, arguments: ToolCallArguments, _path: Path) -> ToolExecResult: view_range = arguments.get("view_range", None) if view_range is None: