Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions tests/tools/test_edit_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import unittest
from pathlib import Path
from unittest.mock import AsyncMock, patch
from tempfile import TemporaryDirectory
from unittest.mock import patch

from trae_agent.tools.base import ToolCallArguments
from trae_agent.tools.edit_tool import TextEditorTool
Expand Down Expand Up @@ -102,13 +103,36 @@ async def test_str_replace_success(self):

async def test_view_directory(self):
self.mock_file_system(exists=True, is_dir=True)
with patch("trae_agent.tools.edit_tool.run", new_callable=AsyncMock) as mock_run:
mock_run.return_value = (0, "file1\nfile2", "")
with patch.object(
self.tool,
"_iter_visible_paths",
return_value=[self.test_dir, self.test_dir / "file1", self.test_dir / "file2"],
):
result = await self.tool.execute(
ToolCallArguments({"command": "view", "path": str(self.test_dir)})
)
self.assertIn("files and directories", result.output)

async def test_view_directory_handles_shell_metacharacters(self):
with TemporaryDirectory() as temp_dir:
root = Path(temp_dir) / "workspace; touch injected"
root.mkdir()
(root / "file.txt").write_text("content")
(root / ".hidden").write_text("hidden")
child_dir = root / "child"
child_dir.mkdir()
(child_dir / "nested.txt").write_text("nested")

result = await self.tool.execute(
ToolCallArguments({"command": "view", "path": str(root)})
)

self.assertEqual(result.error_code, 0)
self.assertIn(str(root / "file.txt"), result.output)
self.assertIn(str(child_dir / "nested.txt"), result.output)
self.assertNotIn(".hidden", result.output)
self.assertFalse((Path(temp_dir) / "injected").exists())

async def test_view_file(self):
self.mock_file_system(exists=True, is_dir=False, content="line1\nline2\nline3")
result = await self.tool.execute(
Expand Down
30 changes: 25 additions & 5 deletions trae_agent/tools/edit_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import override

from trae_agent.tools.base import Tool, ToolCallArguments, ToolError, ToolExecResult, ToolParameter
from trae_agent.tools.run import maybe_truncate, run
from trae_agent.tools.run import maybe_truncate

EditToolSubCommands = [
"view",
Expand Down Expand Up @@ -159,10 +159,16 @@ async def _view(self, path: Path, view_range: list[int] | None = None) -> ToolEx
"The `view_range` parameter is not allowed when `path` points to a directory."
)

return_code, stdout, stderr = await run(rf"find {path} -maxdepth 2 -not -path '*/\.*'")
if not stderr:
stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n"
return ToolExecResult(error_code=return_code, output=stdout, error=stderr)
try:
stdout = "\n".join(str(file_path) for file_path in self._iter_visible_paths(path))
except OSError as exc:
return ToolExecResult(error_code=1, error=str(exc))

stdout = (
f"Here's the files and directories up to 2 levels deep in {path}, "
f"excluding hidden items:\n{stdout}\n"
)
return ToolExecResult(error_code=0, output=stdout)

file_content = self.read_file(path)
init_line = 1
Expand Down Expand Up @@ -307,6 +313,20 @@ def _make_output(
f"Here's the result of running `cat -n` on {file_descriptor}:\n" + file_content + "\n"
)

@staticmethod
def _iter_visible_paths(path: Path):
yield path
for child in path.iterdir():
if child.name.startswith("."):
continue
yield child
if not child.is_dir():
continue
for grandchild in child.iterdir():
if grandchild.name.startswith("."):
continue
yield grandchild

async def _view_handler(self, arguments: ToolCallArguments, _path: Path) -> ToolExecResult:
view_range = arguments.get("view_range", None)
if view_range is None:
Expand Down
28 changes: 24 additions & 4 deletions trae_agent/tools/edit_tool_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,16 @@ async def _view(self, path: Path, view_range: list[int] | None = None) -> ToolEx
"The `view_range` parameter is not allowed when `path` points to a directory."
)

return_code, stdout, stderr = await run(rf"find {path} -maxdepth 2 -not -path '*/\.*'")
if not stderr:
stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n"
return ToolExecResult(error_code=return_code, output=stdout, error=stderr)
try:
stdout = "\n".join(str(file_path) for file_path in self._iter_visible_paths(path))
except OSError as exc:
return ToolExecResult(error_code=1, error=str(exc))

stdout = (
f"Here's the files and directories up to 2 levels deep in {path}, "
f"excluding hidden items:\n{stdout}\n"
)
return ToolExecResult(error_code=0, output=stdout)

file_content = self.read_file(path)
init_line = 1
Expand Down Expand Up @@ -351,6 +357,20 @@ def _make_output(
f"Here's the result of running `cat -n` on {file_descriptor}:\n" + file_content + "\n"
)

@staticmethod
def _iter_visible_paths(path: Path):
yield path
for child in path.iterdir():
if child.name.startswith("."):
continue
yield child
if not child.is_dir():
continue
for grandchild in child.iterdir():
if grandchild.name.startswith("."):
continue
yield grandchild

async def _view_handler(self, arguments: ToolCallArguments, _path: Path) -> ToolExecResult:
view_range = arguments.get("view_range", None)
if view_range is None:
Expand Down