class TestMemoryDocument(unittest.TestCase):
    """Markdown rendering and parsing of ``MemoryDocument`` and ``MemorySection``."""

    def test_to_markdown(self):
        """The rendered Markdown contains the title, both sections and the generation line."""
        first = MemorySection(step_range="1-3", problem="Login fails", conclusion="Regex issue")
        second = MemorySection(step_range="4-5", problem="Fix regex", conclusion="Used re.escape")
        rendered = MemoryDocument(
            task_name="Fix login bug",
            sections=[first, second],
            created_at="2025-01-01 00:00:00",
            step_count=5,
        ).to_markdown()

        # Checked in the same order as listed; the first missing fragment fails the test.
        expected_fragments = (
            "# Long-term Memory — Task: Fix login bug",
            "## Step 1-3",
            "**Problem**: Login fails",
            "**Conclusion**: Regex issue",
            "## Step 4-5",
            "**Problem**: Fix regex",
            "**Conclusion**: Used re.escape",
            "Generated: 2025-01-01 00:00:00 | Steps: 5",
        )
        for fragment in expected_fragments:
            self.assertIn(fragment, rendered)

    def test_from_markdown_round_trip(self):
        """Parsing the output of ``to_markdown`` restores the task name and the section."""
        original = MemoryDocument(
            task_name="Fix login bug",
            sections=[
                MemorySection(step_range="1-3", problem="Login fails", conclusion="Regex issue"),
            ],
            created_at="2025-01-01 00:00:00",
            step_count=3,
        )
        restored = MemoryDocument.from_markdown(original.to_markdown())

        self.assertEqual(restored.task_name, "Fix login bug")
        self.assertEqual(len(restored.sections), 1)
        only_section = restored.sections[0]
        self.assertEqual(only_section.step_range, "1-3")
        self.assertEqual(only_section.problem, "Login fails")
        self.assertEqual(only_section.conclusion, "Regex issue")

    def test_from_markdown_unknown_task(self):
        """Markdown without the expected title parses to task "Unknown" with no sections."""
        restored = MemoryDocument.from_markdown("# unrelated content")
        self.assertEqual(restored.task_name, "Unknown")
        self.assertEqual(len(restored.sections), 0)

    def test_memory_section_heading(self):
        """``heading()`` prefixes the step range with "Step "."""
        heading = MemorySection(step_range="1-5", problem="x", conclusion="y").heading()
        self.assertEqual(heading, "Step 1-5")
test_periodic_trigger_type_name(self): + trigger = PeriodicMemoryTrigger(interval=5) + self.assertEqual(trigger.trigger_type_name(), "periodic(every 5 steps)") + + def test_create_memory_trigger_manual(self): + trigger = create_memory_trigger("manual") + self.assertIsInstance(trigger, ManualMemoryTrigger) + + def test_create_memory_trigger_periodic(self): + trigger = create_memory_trigger("periodic", periodic_interval=7) + self.assertIsInstance(trigger, PeriodicMemoryTrigger) + self.assertEqual(trigger.trigger_type_name(), "periodic(every 7 steps)") + + def test_create_memory_trigger_unknown(self): + with self.assertRaises(ValueError): + create_memory_trigger("unknown") + + +class TestLongTermMemoryConfig(unittest.TestCase): + def test_default_values(self): + config = LongTermMemoryConfig() + self.assertFalse(config.enabled) + self.assertEqual(config.trigger_type, "manual") + self.assertEqual(config.periodic_interval, 10) + self.assertEqual(config.output_dir, "memory/") + self.assertIsNone(config.model) + + def test_custom_values(self): + config = LongTermMemoryConfig( + enabled=True, + trigger_type="periodic", + periodic_interval=5, + output_dir="my_memory/", + ) + self.assertTrue(config.enabled) + self.assertEqual(config.trigger_type, "periodic") + self.assertEqual(config.periodic_interval, 5) + self.assertEqual(config.output_dir, "my_memory/") + + def test_yaml_parsing(self): + yaml_str = """ +model_providers: + openai: + api_key: test-key + provider: openai +models: + default_model: + model_provider: openai + model: gpt-4o + max_tokens: 4096 + temperature: 0.5 + top_p: 1 + top_k: 0 + max_retries: 5 + parallel_tool_calls: true +long_term_memory: + enabled: true + trigger_type: periodic + periodic_interval: 5 + output_dir: my_memory/ +agents: + trae_agent: + model: default_model + max_steps: 50 + enable_lakeview: false + tools: + - bash +""" + config = Config.create(config_string=yaml_str) + self.assertIsNotNone(config.trae_agent) + 
self.assertIsNotNone(config.trae_agent.long_term_memory) + ltm = config.trae_agent.long_term_memory + self.assertTrue(ltm.enabled) + self.assertEqual(ltm.trigger_type, "periodic") + self.assertEqual(ltm.periodic_interval, 5) + self.assertEqual(ltm.output_dir, "my_memory/") + + def test_yaml_parsing_disabled(self): + yaml_str = """ +model_providers: + openai: + api_key: test-key + provider: openai +models: + default_model: + model_provider: openai + model: gpt-4o + max_tokens: 4096 + temperature: 0.5 + top_p: 1 + top_k: 0 + max_retries: 5 + parallel_tool_calls: true +long_term_memory: + enabled: false +agents: + trae_agent: + model: default_model + max_steps: 50 + enable_lakeview: false + tools: + - bash +""" + config = Config.create(config_string=yaml_str) + self.assertIsNotNone(config.trae_agent.long_term_memory) + self.assertFalse(config.trae_agent.long_term_memory.enabled) + + +class TestBuildMemoryMessage(unittest.TestCase): + def test_build_message_with_sections(self): + model = ModelConfig( + model="gpt-4o", + model_provider=ModelProvider(api_key="test", provider="openai"), + max_tokens=4096, + temperature=0.5, + top_p=1, + top_k=0, + max_retries=5, + parallel_tool_calls=True, + ) + ltm = LongTermMemory( + config=LongTermMemoryConfig(enabled=True), + fallback_model=model, + ) + ltm._sections = [ + MemorySection(step_range="1-3", problem="Login fails", conclusion="Regex issue"), + ] + msg = ltm.build_memory_message() + self.assertIsNotNone(msg) + self.assertIn("Login fails", msg.content) + self.assertIn("Regex issue", msg.content) + self.assertEqual(msg.role, "user") + + def test_build_message_no_sections(self): + model = ModelConfig( + model="gpt-4o", + model_provider=ModelProvider(api_key="test", provider="openai"), + max_tokens=4096, + temperature=0.5, + top_p=1, + top_k=0, + max_retries=5, + parallel_tool_calls=True, + ) + ltm = LongTermMemory( + config=LongTermMemoryConfig(enabled=True), + fallback_model=model, + ) + msg = ltm.build_memory_message() + 
class TestInjectMemoryIntoMessages(unittest.TestCase):
    """Exercise ``inject_memory_into_messages`` on an agent built without ``__init__``."""

    def _make_agent(self):
        """Create a minimal concrete ``BaseAgent`` and set its attributes by hand."""
        from trae_agent.agent.base_agent import BaseAgent

        class ConcreteAgent(BaseAgent):
            def new_task(self, task, extra_args=None, tool_names=None):
                pass

            async def cleanup_mcp_clients(self):
                pass

        # ``__new__`` skips ``BaseAgent.__init__``; attributes are assigned manually below.
        bare_agent = ConcreteAgent.__new__(ConcreteAgent)
        initial_state = {
            "_long_term_memory": None,
            "_memory_trigger": None,
            "_max_steps": 10,
            "_tools": [],
            "_cli_console": None,
            "_trajectory_recorder": None,
            "_tool_confirmation_config": MagicMock(enabled=False),
            "_tool_confirmation_approved_all": False,
            "_allowed_command_prefixes": [],
            "_allowed_tool_names": set(),
        }
        for attr_name, attr_value in initial_state.items():
            setattr(bare_agent, attr_name, attr_value)
        return bare_agent

    def test_inject_preserves_system_and_recent(self):
        """Result is: system message, memory summary, then the last ``keep_recent`` messages."""
        agent = self._make_agent()
        fake_memory = MagicMock()
        fake_memory.build_memory_message.return_value = LLMMessage(
            role="user", content="Memory summary"
        )
        agent._long_term_memory = fake_memory

        history = [LLMMessage(role="system", content="system prompt")]
        history.extend(LLMMessage(role="user", content=f"msg {i}") for i in range(25))

        compressed = agent.inject_memory_into_messages(history, keep_recent=4)

        self.assertEqual(compressed[0].content, "system prompt")
        self.assertEqual(compressed[1].content, "Memory summary")
        self.assertEqual(len(compressed), 6)  # system + memory + 4 recent

    def test_no_injection_without_memory(self):
        """Without a long-term memory attached the messages come back unchanged."""
        agent = self._make_agent()
        history = [LLMMessage(role="system", content="system prompt")]
        history.extend(LLMMessage(role="user", content=f"msg {i}") for i in range(25))

        untouched = agent.inject_memory_into_messages(history)
        self.assertEqual(untouched, history)
) + ltm = LongTermMemory( + config=LongTermMemoryConfig(enabled=True), + fallback_model=model, + ) + ltm.set_task("Fix login bug") + + # Mock the LLM client + mock_response = LLMResponse( + content='\nLogin fails with special chars\nRegex in auth/validator.py:42 needs escaping\n', + model="gpt-4o", + ) + ltm._llm_client = MagicMock() + ltm._llm_client.chat = MagicMock(return_value=mock_response) + + steps = [ + AgentStep( + step_number=1, + state=AgentStepState.COMPLETED, + llm_response=LLMResponse(content="Looking at login code", model="gpt-4o"), + ), + AgentStep( + step_number=2, + state=AgentStepState.COMPLETED, + llm_response=LLMResponse(content="Found the issue", model="gpt-4o"), + ), + ] + + import asyncio + + doc = asyncio.run(ltm.extract_memory(steps)) + self.assertIsNotNone(doc) + self.assertEqual(len(doc.sections), 1) + self.assertEqual(doc.sections[0].step_range, "1-3") + self.assertIn("Login fails", doc.sections[0].problem) + self.assertIn("Regex", doc.sections[0].conclusion) + + def test_extract_and_save_creates_file(self): + import asyncio + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + model = ModelConfig( + model="gpt-4o", + model_provider=ModelProvider(api_key="test", provider="openai"), + max_tokens=4096, + temperature=0.5, + top_p=1, + top_k=0, + max_retries=5, + parallel_tool_calls=True, + ) + ltm = LongTermMemory( + config=LongTermMemoryConfig(enabled=True, output_dir=tmpdir), + fallback_model=model, + ) + ltm.set_task("Test task") + + mock_response = LLMResponse( + content='\nTest problem\nTest conclusion\n', + model="gpt-4o", + ) + ltm._llm_client = MagicMock() + ltm._llm_client.chat = MagicMock(return_value=mock_response) + + steps = [ + AgentStep( + step_number=1, + state=AgentStepState.COMPLETED, + llm_response=LLMResponse(content="Step 1", model="gpt-4o"), + ), + ] + + path = asyncio.run(ltm.extract_and_save(steps)) + self.assertIsNotNone(path) + import os + + self.assertTrue(os.path.exists(path)) + with open(path) 
class TestMemoryDocumentSessionId(unittest.TestCase):
    """Session-id handling when rendering and parsing ``MemoryDocument`` Markdown."""

    @staticmethod
    def _single_section_doc(**extra_fields):
        """Build the one-section fixture document, optionally with extra constructor fields."""
        return MemoryDocument(
            task_name="Fix login bug",
            sections=[
                MemorySection(step_range="1-3", problem="Login fails", conclusion="Regex issue"),
            ],
            created_at="2025-01-01 12:00:00",
            step_count=3,
            **extra_fields,
        )

    def test_to_markdown_with_session_id(self):
        """A document with a session id renders a ``Session:`` line carrying it."""
        rendered = self._single_section_doc(session_id="session_20250101_120000").to_markdown()
        self.assertIn("Session: session_20250101_120000", rendered)

    def test_to_markdown_without_session_id(self):
        """A document built without a session id renders no ``Session:`` line."""
        rendered = self._single_section_doc().to_markdown()
        self.assertNotIn("Session:", rendered)

    def test_from_markdown_with_session_id(self):
        """The session id is read back from a ``Session:`` line."""
        source = """# Long-term Memory — Task: Fix login bug
Session: session_20250101_120000
Generated: 2025-01-01 12:00:00 | Steps: 3

## Step 1-3
**Problem**: Login fails
**Conclusion**: Regex issue
"""
        parsed = MemoryDocument.from_markdown(source)
        self.assertEqual(parsed.session_id, "session_20250101_120000")

    def test_from_markdown_without_session_id(self):
        """Markdown lacking a ``Session:`` line parses to an empty session id."""
        source = """# Long-term Memory — Task: Fix login bug
Generated: 2025-01-01 12:00:00 | Steps: 3

## Step 1-3
**Problem**: Login fails
**Conclusion**: Regex issue
"""
        parsed = MemoryDocument.from_markdown(source)
        self.assertEqual(parsed.session_id, "")
config=LongTermMemoryConfig(enabled=True, output_dir=tmpdir), + fallback_model=model, + ) + + def test_load_memory_populates_preloaded_sections(self): + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + ltm = self._make_ltm(tmpdir) + # Write a memory file + md_path = f"{tmpdir}/test_memory.md" + with open(md_path, "w") as f: + f.write("# Long-term Memory — Task: Previous task\n\n## Step 1-2\n**Problem**: Old problem\n**Conclusion**: Old conclusion\n") + ltm.load_memory(md_path) + self.assertEqual(len(ltm._preloaded_sections), 1) + self.assertEqual(ltm._preloaded_sections[0].problem, "Old problem") + + def test_load_memory_file_not_found(self): + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + ltm = self._make_ltm(tmpdir) + with self.assertRaises(FileNotFoundError): + ltm.load_memory(f"{tmpdir}/nonexistent.md") + + def test_load_memory_empty_sections(self): + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + ltm = self._make_ltm(tmpdir) + md_path = f"{tmpdir}/empty.md" + with open(md_path, "w") as f: + f.write("# Long-term Memory — Task: Empty\n\nNo sections here.\n") + ltm.load_memory(md_path) + self.assertEqual(len(ltm._preloaded_sections), 0) + + def test_preloaded_sections_survive_set_task(self): + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + ltm = self._make_ltm(tmpdir) + md_path = f"{tmpdir}/test_memory.md" + with open(md_path, "w") as f: + f.write("# Long-term Memory — Task: Previous task\n\n## Step 1-2\n**Problem**: Old problem\n**Conclusion**: Old conclusion\n") + ltm.load_memory(md_path) + ltm._sections = [MemorySection(step_range="1-1", problem="Current", conclusion="Current")] + ltm.set_task("New task") + self.assertEqual(len(ltm._preloaded_sections), 1) + self.assertEqual(len(ltm._sections), 0) + + def test_build_memory_message_combines_both(self): + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + ltm = self._make_ltm(tmpdir) + md_path = 
class TestMemoryIndex(unittest.TestCase):
    """Checks for the ``index.json`` bookkeeping performed by ``LongTermMemory``."""

    def _make_ltm(self, tmpdir: str) -> LongTermMemory:
        """Return an enabled ``LongTermMemory`` whose output directory is *tmpdir*."""
        provider = ModelProvider(api_key="test", provider="openai")
        fallback = ModelConfig(
            model="gpt-4o",
            model_provider=provider,
            max_tokens=4096,
            temperature=0.5,
            top_p=1,
            top_k=0,
            max_retries=5,
            parallel_tool_calls=True,
        )
        memory_config = LongTermMemoryConfig(enabled=True, output_dir=tmpdir)
        return LongTermMemory(config=memory_config, fallback_model=fallback)

    @staticmethod
    def _read_index(tmpdir):
        """Load and decode ``index.json`` from *tmpdir*."""
        import json

        return json.loads(Path(f"{tmpdir}/index.json").read_text())

    def test_update_index_creates_entry(self):
        """One ``_update_index`` call records the session together with its task name."""
        import tempfile

        with tempfile.TemporaryDirectory() as tmpdir:
            ltm = self._make_ltm(tmpdir)
            ltm.set_session_id("session_test_001")
            ltm.set_task("Test task")
            ltm._update_index(f"{tmpdir}/memory_test.md")

            sessions = self._read_index(tmpdir)["sessions"]
            self.assertIn("session_test_001", sessions)
            self.assertEqual(sessions["session_test_001"]["task_name"], "Test task")

    def test_update_index_appends_memory_file(self):
        """Two ``_update_index`` calls leave two entries in the session's ``memory_files``."""
        import tempfile

        with tempfile.TemporaryDirectory() as tmpdir:
            ltm = self._make_ltm(tmpdir)
            ltm.set_session_id("session_test_002")
            ltm.set_task("Test task")
            for file_name in ("memory_1.md", "memory_2.md"):
                ltm._update_index(f"{tmpdir}/{file_name}")

            entry = self._read_index(tmpdir)["sessions"]["session_test_002"]
            self.assertEqual(len(entry["memory_files"]), 2)

    def test_query_index_empty(self):
        """Querying a path that does not exist yields the empty version-1 index."""
        empty_index = LongTermMemory.query_index("/nonexistent/path/index.json")
        self.assertEqual(empty_index, {"version": 1, "sessions": {}})

    def test_set_trajectory_file(self):
        """``set_trajectory_file`` stores the trajectory path in the session's index entry."""
        import tempfile

        with tempfile.TemporaryDirectory() as tmpdir:
            ltm = self._make_ltm(tmpdir)
            ltm.set_session_id("session_test_003")
            ltm.set_task("Test task")
            ltm._update_index(f"{tmpdir}/memory_test.md")
            ltm.set_trajectory_file("trajectories/trajectory_test.json")

            entry = self._read_index(tmpdir)["sessions"]["session_test_003"]
            self.assertEqual(
                entry["trajectory_file"],
                "trajectories/trajectory_test.json",
            )
self.assertEqual(len(ltm2._preloaded_sections), 1) + self.assertEqual(ltm2._preloaded_sections[0].problem, "First problem") + + # Verify build_memory_message includes both + ltm2._sections = [MemorySection(step_range="1-1", problem="New problem", conclusion="New conclusion")] + msg = ltm2.build_memory_message() + self.assertIn("First problem", msg.content) + self.assertIn("New problem", msg.content) diff --git a/trae_agent/agent/agent.py b/trae_agent/agent/agent.py index bbca94f01..f9d57d243 100644 --- a/trae_agent/agent/agent.py +++ b/trae_agent/agent/agent.py @@ -20,6 +20,8 @@ def __init__( cli_console: CLIConsole | None = None, docker_config: dict | None = None, docker_keep: bool = True, + session_id: str | None = None, + memory_path: str | None = None, ): if isinstance(agent_type, str): agent_type = AgentType(agent_type) @@ -56,6 +58,23 @@ def __init__( self.agent.set_trajectory_recorder(self.trajectory_recorder) + # Set up long-term memory + if config.trae_agent and config.trae_agent.long_term_memory and config.trae_agent.long_term_memory.enabled: + from trae_agent.utils.long_term_memory import LongTermMemory + + ltm = LongTermMemory( + config=config.trae_agent.long_term_memory, + fallback_model=config.trae_agent.model, + ) + self.agent.set_long_term_memory(ltm) + + # Set session ID and preload memory if provided + if self.agent.long_term_memory: + if session_id: + self.agent.long_term_memory.set_session_id(session_id) + if memory_path: + self.agent.long_term_memory.load_memory(memory_path) + async def run( self, task: str, @@ -64,6 +83,10 @@ async def run( ): self.agent.new_task(task, extra_args, tool_names) + if self.agent.long_term_memory: + self.agent.long_term_memory.set_task(task) + self.agent.long_term_memory.set_trajectory_file(self.trajectory_file) + if self.agent.allow_mcp_servers: if self.agent.cli_console: self.agent.cli_console.print("Initialising MCP tools...") diff --git a/trae_agent/agent/base_agent.py b/trae_agent/agent/base_agent.py index 
def set_long_term_memory(self, ltm: LongTermMemory | None) -> None:
    """Attach (or detach) the long-term memory system and its matching trigger.

    Args:
        ltm: The memory system to use, or ``None`` to remove both the memory
            system and the trigger.
    """
    from trae_agent.utils.memory_trigger import create_memory_trigger

    self._long_term_memory = ltm

    if ltm is None:
        self._memory_trigger = None
        return

    # The trigger is derived from the memory system's own configuration.
    memory_settings = ltm._config
    self._memory_trigger = create_memory_trigger(
        memory_settings.trigger_type, memory_settings.periodic_interval
    )
def inject_memory_into_messages(
    self, messages: list[LLMMessage], keep_recent: int = 4
) -> list[LLMMessage]:
    """Replace older messages with a compressed memory summary.

    The result is the first message (treated as the system message), the
    memory summary message, and the last ``keep_recent`` of the remaining
    messages.

    Args:
        messages: Conversation so far; index 0 is kept as the system message.
        keep_recent: How many trailing non-system messages to keep. Zero or a
            negative value keeps none of them.

    Returns:
        A new, shortened message list. ``messages`` itself is returned
        unchanged when it is empty, when no long-term memory is attached, or
        when the memory system has no summary message to offer.
    """
    # Nothing to compress, or nothing to compress it with.
    if not messages or not self._long_term_memory:
        return messages

    memory_msg = self._long_term_memory.build_memory_message()
    if memory_msg is None:
        return messages

    system_msg, *history = messages

    # ``history[-0:]`` would be the *whole* list, so a non-positive count has to
    # be handled explicitly. Slicing ``history`` instead of ``messages`` also
    # guarantees the system message is never included a second time.
    recent = history[-keep_recent:] if keep_recent > 0 else []

    # NOTE(review): dropping the middle of the conversation may separate a tool
    # result from the assistant message that requested it — confirm the LLM
    # clients accept such a history.
    return [system_msg, memory_msg, *recent]
None = None, docker_container_id: str | None = None, @@ -349,6 +371,8 @@ def run( ) sys.exit(1) + session_id = _generate_session_id() + agent = Agent( agent_type, config, @@ -356,8 +380,13 @@ def run( cli_console, docker_config=docker_config, docker_keep=docker_keep, + session_id=session_id, + memory_path=memory_path, ) + if memory_path and agent.agent.long_term_memory: + console.print(f"[cyan]Preloaded memory from: {memory_path}[/cyan]") + if not docker_config: try: os.chdir(working_dir) @@ -383,6 +412,14 @@ def run( console.print(f"\n[green]Trajectory saved to: {agent.trajectory_file}[/green]") + # Extract long-term memory if requested + if extract_memory and agent.agent.long_term_memory: + memory_path = asyncio.run(agent.agent.extract_memory_now()) + if memory_path: + console.print(f"[cyan]Long-term memory saved to: {memory_path}[/cyan]") + else: + console.print("[yellow]No memory could be extracted from this execution.[/yellow]") + except KeyboardInterrupt: console.print("\n[yellow]Task execution interrupted by user[/yellow]") console.print(f"[blue]Partial trajectory saved to: {agent.trajectory_file}[/blue]") @@ -437,6 +474,13 @@ def run( help="Type of agent to use (trae_agent)", default="trae_agent", ) +@click.option( + "--memory", + "memory_path", + type=click.Path(exists=True, dir_okay=False), + default=None, + help="Path to a .md memory file to preload into the session", +) def interactive( provider: str | None = None, model: str | None = None, @@ -447,6 +491,7 @@ def interactive( trajectory_file: str | None = None, console_type: str | None = "simple", agent_type: str | None = "trae_agent", + memory_path: str | None = None, ): """ This function starts an interactive session with Trae Agent. 
@cli.command()
@click.option(
    "--trajectory-file",
    "-t",
    required=True,
    help="Path to a trajectory JSON file to extract memory from",
)
@click.option(
    "--output-dir",
    "-o",
    default="memory/",
    help="Output directory for the memory Markdown file",
)
@click.option(
    "--config-file",
    help="Path to configuration file (for model settings)",
    default="trae_config.yaml",
    envvar="TRAE_CONFIG_FILE",
)
def memory(
    trajectory_file: str,
    output_dir: str,
    config_file: str,
):
    """Extract long-term memory from a trajectory file."""
    # NOTE: the docstring above is kept to one line on purpose; click shows it
    # verbatim as the command's help text.
    #
    # Flow: read the trajectory JSON, rebuild AgentStep objects from its
    # "agent_steps" list, resolve the model from the config file, then run
    # LongTermMemory.extract_and_save and print a preview of the saved file.
    # Every failure path prints a message and exits with status 1.
    import json

    from trae_agent.agent.agent_basics import AgentStep, AgentStepState
    from trae_agent.tools.base import ToolCall, ToolResult
    from trae_agent.utils.config import LongTermMemoryConfig
    from trae_agent.utils.llm_clients.llm_basics import LLMResponse, LLMUsage
    from trae_agent.utils.long_term_memory import LongTermMemory

    # Load trajectory JSON
    traj_path = Path(trajectory_file)
    if not traj_path.exists():
        console.print(f"[red]Error: Trajectory file not found: {traj_path}[/red]")
        sys.exit(1)

    try:
        with open(traj_path, "r", encoding="utf-8") as f:
            traj_data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        # A truncated or hand-edited trajectory should produce a readable
        # error, not a traceback. markup=False keeps brackets in the parser
        # message from being interpreted as Rich tags.
        console.print(f"[red]Error: Trajectory file is not valid JSON: {traj_path}[/red]")
        console.print(f"Reason: {exc}", style="red", markup=False)
        sys.exit(1)

    if not isinstance(traj_data, dict):
        # Everything below uses dict access on the top-level value.
        console.print(f"[red]Error: Trajectory file is not a JSON object: {traj_path}[/red]")
        sys.exit(1)

    # Rebuild AgentStep objects from trajectory data
    steps: list[AgentStep] = []
    for step_data in traj_data.get("agent_steps") or []:
        llm_response = None
        resp_data = step_data.get("llm_response")
        if resp_data:
            usage = None
            usage_data = resp_data.get("usage")
            if usage_data:
                usage = LLMUsage(
                    input_tokens=usage_data.get("input_tokens", 0),
                    output_tokens=usage_data.get("output_tokens", 0),
                )
            tool_calls = None
            tc_data = resp_data.get("tool_calls")
            if tc_data:
                tool_calls = [
                    ToolCall(
                        call_id=tc.get("call_id", ""),
                        name=tc.get("name", ""),
                        arguments=tc.get("arguments", {}),
                    )
                    for tc in tc_data
                ]
            llm_response = LLMResponse(
                content=resp_data.get("content", ""),
                model=resp_data.get("model", ""),
                finish_reason=resp_data.get("finish_reason"),
                usage=usage,
                tool_calls=tool_calls,
            )

        tool_results = None
        tr_data = step_data.get("tool_results")
        if tr_data:
            tool_results = [
                ToolResult(
                    call_id=tr.get("call_id", ""),
                    name=tr.get("name", ""),
                    success=tr.get("success", False),
                    result=tr.get("result"),
                    error=tr.get("error"),
                )
                for tr in tr_data
            ]

        steps.append(
            AgentStep(
                step_number=step_data.get("step_number", 0),
                state=AgentStepState(step_data.get("state", "completed")),
                llm_response=llm_response,
                tool_results=tool_results,
                reflection=step_data.get("reflection"),
                error=step_data.get("error"),
            )
        )

    if not steps:
        console.print("[yellow]No agent steps found in the trajectory file.[/yellow]")
        sys.exit(1)

    # Load config for model settings
    config_file = resolve_config_file(config_file)
    config_error: Exception | None = None
    try:
        config = Config.create(config_file=config_file)
        fallback_model = config.trae_agent.model if config.trae_agent else None
    except Exception as exc:  # CLI boundary: report the cause below instead of a traceback
        fallback_model = None
        config_error = exc

    if not fallback_model:
        console.print("[red]Error: No model configuration found. Check your config file.[/red]")
        if config_error is not None:
            # Surface why loading failed; previously the cause was discarded.
            console.print(f"Reason: {config_error}", style="red", markup=False)
        sys.exit(1)

    # Create LongTermMemory and extract
    ltm_config = LongTermMemoryConfig(
        enabled=True,
        output_dir=output_dir,
        model=fallback_model,
    )
    ltm = LongTermMemory(config=ltm_config, fallback_model=fallback_model)
    ltm.set_task(traj_data.get("task", "Unknown"))

    console.print(f"[blue]Extracting memory from {len(steps)} steps...[/blue]")

    memory_path = asyncio.run(ltm.extract_and_save(steps))
    if memory_path:
        console.print(f"[green]Long-term memory saved to: {memory_path}[/green]")
        # Print preview
        md_content = Path(memory_path).read_text(encoding="utf-8")
        from rich.markdown import Markdown as RichMarkdown

        console.print(Panel(RichMarkdown(md_content), title="Memory Preview", border_style="cyan"))
    else:
        console.print("[yellow]No memory could be extracted from this trajectory.[/yellow]")
def _print_memory_summary(self):
    """Print one panel per long-term memory section, if a document was set.

    Reads ``self._memory_doc`` (absent or falsy means nothing is printed) and
    writes a banner followed by a cyan panel for every section to
    ``self.console``.
    """
    from rich.markup import escape

    doc = getattr(self, "_memory_doc", None)
    if not doc:
        return

    self.console.print("\n" + "=" * 60)
    self.console.print("[bold cyan]Long-term Memory[/bold cyan]")
    self.console.print("=" * 60)

    for section in doc.sections:
        # Section text comes from an LLM or a loaded file and may contain
        # square brackets (paths, type hints, list literals). Escape it so
        # Rich does not parse it as markup, which can raise MarkupError or
        # silently drop the bracketed text.
        body = (
            f"[bold]Problem:[/bold] {escape(section.problem)}\n"
            f"[bold]Conclusion:[/bold] {escape(section.conclusion)}"
        )
        self.console.print(
            Panel(
                body,
                title=escape(section.heading()),
                border_style="cyan",
                width=80,
            )
        )
"trae_agent": trae_agent_config = TraeAgentConfig( **agent_config, + long_term_memory=ltm_config, mcp_servers_config=mcp_servers_config, allow_mcp_servers=allow_mcp_servers, ) diff --git a/trae_agent/utils/long_term_memory.py b/trae_agent/utils/long_term_memory.py new file mode 100644 index 000000000..5f29cf3ff --- /dev/null +++ b/trae_agent/utils/long_term_memory.py @@ -0,0 +1,330 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +"""Long-term memory extraction and visualization for trae-agent.""" + +import json +import re +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path + +from trae_agent.agent.agent_basics import AgentStep +from trae_agent.utils.config import LongTermMemoryConfig +from trae_agent.utils.llm_clients.llm_basics import LLMMessage +from trae_agent.utils.llm_clients.llm_client import LLMClient + +MEMORY_EXTRACTOR_PROMPT = """ +Given the following execution steps of an AI agent solving a task, extract the key "problem" and "conclusion" for each logical group of steps. + +A "problem" describes what the agent was trying to solve or figure out. +A "conclusion" describes what the agent discovered, decided, or accomplished. + +Group consecutive steps that address the same sub-problem into a single section. +Number the sections sequentially using the step range (e.g., steps="1-3", steps="4-5"). + +Output format — repeat this block for each group: + +... +... + + +Be concise. Each problem and conclusion should be at most 2 sentences. +Focus on facts: file paths, function names, root causes, specific changes made. +Do not include any other commentary. +""" + +group_re = re.compile(r'\s*(.*?)\s*(.*?)\s*', re.DOTALL) + + +@dataclass +class MemorySection: + """A single section of extracted memory.""" + + step_range: str # e.g. 
@dataclass
class MemoryDocument:
    """A complete memory document extracted from agent execution.

    The document renders to Markdown with :meth:`to_markdown` and is parsed
    back with :meth:`from_markdown`. The on-disk format is line based: a
    title line, optional ``Session:`` and ``Generated: ... | Steps: N``
    lines, then one ``## Step ...`` block per section.
    """

    task_name: str
    session_id: str = ""
    sections: list["MemorySection"] = field(default_factory=list)
    created_at: str = ""
    step_count: int = 0

    @staticmethod
    def _single_line(text: str) -> str:
        """Fold *text* onto one line; text without line breaks is returned unchanged."""
        if "\n" not in text and "\r" not in text:
            return text
        return re.sub(r"\s*[\r\n]+\s*", " ", text.strip())

    def to_markdown(self) -> str:
        """Render the full document as Markdown.

        Task name, problem and conclusion are folded onto a single line each,
        because :meth:`from_markdown` reads exactly one line per field;
        embedded line breaks would otherwise make the section unparseable.
        """
        lines = [f"# Long-term Memory — Task: {self._single_line(self.task_name)}", ""]
        if self.session_id:
            lines += [f"Session: {self.session_id}", ""]
        if self.created_at:
            lines += [f"Generated: {self.created_at} | Steps: {self.step_count}", ""]
        for section in self.sections:
            lines += [
                f"## {section.heading()}",
                f"**Problem**: {self._single_line(section.problem)}",
                f"**Conclusion**: {self._single_line(section.conclusion)}",
                "",
            ]
        return "\n".join(lines)

    @classmethod
    def from_markdown(cls, markdown_text: str) -> "MemoryDocument":
        """Parse Markdown produced by :meth:`to_markdown` back into a document.

        Missing pieces fall back to defaults: task name ``"Unknown"``, empty
        session id and timestamp, step count ``0``, and no sections.
        """
        task_match = re.search(r"# Long-term Memory — Task: (.+)", markdown_text)
        task_name = task_match.group(1).strip() if task_match else "Unknown"

        session_match = re.search(r"^Session: (.+)$", markdown_text, re.MULTILINE)
        session_id = session_match.group(1).strip() if session_match else ""

        # Restore the metadata line written by to_markdown so that a
        # save/load round trip keeps created_at and step_count.
        generated_match = re.search(
            r"^Generated: (.+?) \| Steps: (\d+)[ \t\r]*$", markdown_text, re.MULTILINE
        )
        created_at = generated_match.group(1).strip() if generated_match else ""
        step_count = int(generated_match.group(2)) if generated_match else 0

        sections: list[MemorySection] = []
        # Match "## Step N-M" blocks; problem and conclusion may be empty.
        section_pattern = re.compile(
            r"## Step ([^\n]+)\n\*\*Problem\*\*: ([^\n]*)\n\*\*Conclusion\*\*: ([^\n]*)"
        )
        for match in section_pattern.finditer(markdown_text):
            sections.append(
                MemorySection(
                    step_range=match.group(1).strip(),
                    problem=match.group(2).strip(),
                    conclusion=match.group(3).strip(),
                )
            )

        return cls(
            task_name=task_name,
            session_id=session_id,
            sections=sections,
            created_at=created_at,
            step_count=step_count,
        )
def load_memory(self, path: str) -> None:
    """Preload sections from a previously saved Markdown memory file.

    Parsed sections replace ``self._preloaded_sections``; ``set_task`` does
    not reset that attribute, so the preloaded context carries across tasks.
    A file that yields no sections leaves the current preloaded list as is.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    source = Path(path)
    if source.exists():
        parsed = MemoryDocument.from_markdown(source.read_text(encoding="utf-8"))
        if parsed.sections:
            self._preloaded_sections = parsed.sections
        return
    raise FileNotFoundError(f"Memory file not found: {path}")
async def extract_memory(self, steps: list[AgentStep]) -> MemoryDocument | None:
    """Extract memory sections from agent steps using the LLM.

    Steps are rendered with ``self._agent_step_str``; steps that render to
    nothing are skipped. The rendered text (cut to its last 300,000
    characters when longer) is sent with ``MEMORY_EXTRACTOR_PROMPT``, and the
    reply is parsed with ``group_re``. The same request is retried up to 10
    more times while the reply contains no parsable group.

    Args:
        steps: Agent steps to summarise.

    Returns:
        A ``MemoryDocument`` holding the parsed sections, or ``None`` when no
        step produced text or no section could be parsed. On success the
        sections are also stored on ``self._sections``.
    """
    import copy

    max_chars = 300_000
    max_retries = 10

    # Build step text
    step_texts: list[str] = []
    for step in steps:
        step_str = self._agent_step_str(step)
        if step_str:
            step_texts.append(f"\n{step_str}\n")

    if not step_texts:
        return None

    steps_formatted = "\n\n".join(step_texts)

    # Truncate if too long: keep the tail, i.e. the most recent steps.
    if len(steps_formatted) > max_chars:
        steps_formatted = steps_formatted[-max_chars:]

    llm_messages = [
        LLMMessage(
            role="user",
            content=f"Below are the execution steps of an AI agent:\n\n{steps_formatted}",
        ),
        LLMMessage(role="assistant", content="I understand."),
        LLMMessage(role="user", content=MEMORY_EXTRACTOR_PROMPT),
    ]

    # Use a private copy for the low-temperature request. self._model_config
    # may be the very object the agent uses as its own model configuration
    # (the fallback model), so mutating it in place would change the agent's
    # sampling temperature for every later call.
    extraction_config = copy.copy(self._model_config)
    extraction_config.temperature = 0.1

    content = ""
    # One initial attempt plus up to max_retries retries while parsing fails.
    for _attempt in range(1 + max_retries):
        llm_response = self._llm_client.chat(
            model_config=extraction_config,
            messages=llm_messages,
            reuse_history=False,
        )
        content = (llm_response.content or "").strip()
        if group_re.search(content):
            break

    # Parse response
    # NOTE(review): this relies on group_re exposing three capture groups
    # (step range, problem, conclusion) — confirm against the module-level
    # pattern.
    sections: list[MemorySection] = [
        MemorySection(
            step_range=match.group(1).strip(),
            problem=match.group(2).strip(),
            conclusion=match.group(3).strip(),
        )
        for match in group_re.finditer(content)
    ]

    if not sections:
        return None

    self._sections = sections

    return MemoryDocument(
        task_name=self._task_name,
        sections=sections,
        created_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        step_count=len(steps),
    )
def _load_index(self) -> dict:
    """Load the memory index from disk, or return an empty index structure.

    Returns:
        The parsed index, always a dict with a dict under ``"sessions"``.
        A missing file, unparsable JSON, or a top-level value that is not an
        object yields ``{"version": 1, "sessions": {}}``; a missing or
        non-dict ``"sessions"`` entry is replaced with an empty dict.
    """
    path = self._index_path()
    try:
        raw = path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return {"version": 1, "sessions": {}}

    try:
        index = json.loads(raw)
    except json.JSONDecodeError:
        # The index is bookkeeping written after the memory file itself; a
        # damaged index must not make an otherwise successful extraction fail.
        return {"version": 1, "sessions": {}}

    if not isinstance(index, dict):
        return {"version": 1, "sessions": {}}
    if not isinstance(index.get("sessions"), dict):
        # Callers index into index["sessions"] directly.
        index["sessions"] = {}
    return index
indent=2, ensure_ascii=False), encoding="utf-8") + + def _update_index(self, memory_file_path: str) -> None: + """Update the index with the current session's memory file.""" + if not self._session_id: + return + index = self._load_index() + sessions = index["sessions"] + if self._session_id not in sessions: + sessions[self._session_id] = { + "task_name": self._task_name, + "memory_files": [memory_file_path], + "trajectory_file": "", + "created_at": datetime.now().isoformat(), + } + else: + entry = sessions[self._session_id] + if memory_file_path not in entry["memory_files"]: + entry["memory_files"].append(memory_file_path) + self._save_index(index) + + def set_trajectory_file(self, trajectory_file: str) -> None: + """Record the trajectory file path for this session in the index.""" + if not self._session_id: + return + index = self._load_index() + sessions = index["sessions"] + if self._session_id in sessions: + sessions[self._session_id]["trajectory_file"] = trajectory_file + self._save_index(index) + + @classmethod + def query_index(cls, index_path: str) -> dict: + """Read and return the memory index. Used by CLI commands.""" + path = Path(index_path) + if not path.exists(): + return {"version": 1, "sessions": {}} + return json.loads(path.read_text(encoding="utf-8")) diff --git a/trae_agent/utils/memory_trigger.py b/trae_agent/utils/memory_trigger.py new file mode 100644 index 000000000..645ca0ae8 --- /dev/null +++ b/trae_agent/utils/memory_trigger.py @@ -0,0 +1,68 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +"""Memory trigger factory for long-term memory extraction.""" + +from abc import ABC, abstractmethod + +from trae_agent.agent.agent_basics import AgentStep + + +class MemoryTrigger(ABC): + """Abstract base class for memory extraction triggers.""" + + @abstractmethod + def should_trigger(self, step: AgentStep, steps_completed: int) -> bool: + """Decide whether to trigger memory extraction. 
+ + Args: + step: The step that just completed. + steps_completed: Total number of completed steps so far. + + Returns: + True if memory extraction should be triggered now. + """ + pass + + @abstractmethod + def trigger_type_name(self) -> str: + """Return a human-readable name for this trigger type.""" + pass + + +class ManualMemoryTrigger(MemoryTrigger): + """Only triggers when explicitly called by the user.""" + + def should_trigger(self, step: AgentStep, steps_completed: int) -> bool: + return False + + def trigger_type_name(self) -> str: + return "manual" + + +class PeriodicMemoryTrigger(MemoryTrigger): + """Triggers every N steps.""" + + def __init__(self, interval: int = 10): + self._interval = interval + self._last_triggered_at: int = 0 + + def should_trigger(self, step: AgentStep, steps_completed: int) -> bool: + if steps_completed > 0 and steps_completed % self._interval == 0 and steps_completed != self._last_triggered_at: + self._last_triggered_at = steps_completed + return True + return False + + def trigger_type_name(self) -> str: + return f"periodic(every {self._interval} steps)" + + +def create_memory_trigger(trigger_type: str, periodic_interval: int = 10) -> MemoryTrigger: + """Factory: create the appropriate trigger based on config.""" + match trigger_type: + case "manual": + return ManualMemoryTrigger() + case "periodic": + return PeriodicMemoryTrigger(interval=periodic_interval) + case _: + raise ValueError(f"Unknown memory trigger_type: {trigger_type}") diff --git a/trae_config.yaml.example b/trae_config.yaml.example index 21f9d4d20..a471ea695 100644 --- a/trae_config.yaml.example +++ b/trae_config.yaml.example @@ -18,6 +18,13 @@ mcp_servers: lakeview: model: lakeview_model +# long_term_memory: +# enabled: true +# trigger_type: periodic # "manual" or "periodic" +# periodic_interval: 10 # steps between auto-extractions (for periodic trigger) +# output_dir: memory/ # directory for Markdown memory files +# # model: ltm_model # optional separate 
model; falls back to agent model + model_providers: anthropic: api_key: your_anthropic_api_key