Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions tests/tools/test_ckg_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT

import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import MagicMock, patch

from trae_agent.tools.base import ToolCallArguments
from trae_agent.tools.ckg_tool import CKGTool


class TestCKGTool(unittest.IsolatedAsyncioTestCase):
async def test_execute_accepts_file_path_inside_codebase(self):
with TemporaryDirectory() as tmpdir:
code_file = Path(tmpdir) / "example.py"
code_file.write_text("def target():\n return 1\n")

ckg_database = MagicMock()
ckg_database.query_function.return_value = []

with patch(
"trae_agent.tools.ckg_tool.CKGDatabase", return_value=ckg_database
) as ckg_database_cls:
result = await CKGTool().execute(
ToolCallArguments(
{
"command": "search_function",
"path": str(code_file),
"identifier": "target",
"print_body": False,
}
)
)

self.assertIsNone(result.error)
self.assertEqual(result.output, "No functions named target found.")
ckg_database_cls.assert_called_once_with(code_file.parent)

async def test_execute_uses_git_root_for_file_path(self):
with TemporaryDirectory() as tmpdir:
codebase_root = Path(tmpdir)
nested_dir = codebase_root / "pkg"
nested_dir.mkdir()
code_file = nested_dir / "example.py"
code_file.write_text("def target():\n return 1\n")

ckg_database = MagicMock()
ckg_database.query_function.return_value = []
git_result = MagicMock(returncode=0, stdout=f"{codebase_root}\n")

with (
patch("trae_agent.tools.ckg_tool.subprocess.run", return_value=git_result),
patch(
"trae_agent.tools.ckg_tool.CKGDatabase", return_value=ckg_database
) as ckg_database_cls,
):
result = await CKGTool().execute(
ToolCallArguments(
{
"command": "search_function",
"path": str(code_file),
"identifier": "target",
}
)
)

self.assertIsNone(result.error)
ckg_database_cls.assert_called_once_with(codebase_root)


if __name__ == "__main__":
unittest.main()
24 changes: 22 additions & 2 deletions trae_agent/tools/ckg_tool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT

import subprocess
from pathlib import Path
from typing import override

Expand Down Expand Up @@ -60,7 +61,7 @@ def get_parameters(self) -> list[ToolParameter]:
ToolParameter(
name="path",
type="string",
description="The path to the codebase.",
description="The path to the codebase root or to a file inside the codebase.",
required=True,
),
ToolParameter(
Expand Down Expand Up @@ -105,7 +106,9 @@ async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
error=f"Codebase path {path} does not exist",
error_code=-1,
)
if not codebase_path.is_dir():
if codebase_path.is_file():
codebase_path = self._resolve_codebase_path_from_file(codebase_path)
elif not codebase_path.is_dir():
return ToolExecResult(
error=f"Codebase path {path} is not a directory",
error_code=-1,
Expand All @@ -132,6 +135,23 @@ async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
case _:
return ToolExecResult(error=f"Invalid command: {command}", error_code=-1)

def _resolve_codebase_path_from_file(self, file_path: Path) -> Path:
parent = file_path.parent
try:
result = subprocess.run(
["git", "rev-parse", "--show-toplevel"],
cwd=parent,
capture_output=True,
text=True,
timeout=5,
check=False,
)
if result.returncode == 0 and result.stdout.strip():
return Path(result.stdout.strip())
except (FileNotFoundError, subprocess.TimeoutExpired):
pass
return parent

def _search_function(
self, ckg_database: CKGDatabase, identifier: str, print_body: bool = True
) -> str:
Expand Down