diff --git a/tests/utils/test_long_term_memory.py b/tests/utils/test_long_term_memory.py
new file mode 100644
index 000000000..a904d8464
--- /dev/null
+++ b/tests/utils/test_long_term_memory.py
@@ -0,0 +1,659 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+# SPDX-License-Identifier: MIT
+
+import unittest
+from pathlib import Path
+from unittest.mock import MagicMock
+
+from trae_agent.agent.agent_basics import AgentStep, AgentStepState
+from trae_agent.utils.config import Config, LongTermMemoryConfig, ModelConfig, ModelProvider
+from trae_agent.utils.llm_clients.llm_basics import LLMMessage, LLMResponse
+from trae_agent.utils.long_term_memory import (
+ LongTermMemory,
+ MemoryDocument,
+ MemorySection,
+)
+from trae_agent.utils.memory_trigger import (
+ ManualMemoryTrigger,
+ PeriodicMemoryTrigger,
+ create_memory_trigger,
+)
+
+
+class TestMemoryDocument(unittest.TestCase):
+ """Exercise MemoryDocument Markdown rendering/parsing and MemorySection.heading()."""
+ def test_to_markdown(self):
+ # A two-section document must render its title, every section's heading,
+ # problem and conclusion, and the "Generated ... | Steps: N" line.
+ doc = MemoryDocument(
+ task_name="Fix login bug",
+ sections=[
+ MemorySection(step_range="1-3", problem="Login fails", conclusion="Regex issue"),
+ MemorySection(step_range="4-5", problem="Fix regex", conclusion="Used re.escape"),
+ ],
+ created_at="2025-01-01 00:00:00",
+ step_count=5,
+ )
+ md = doc.to_markdown()
+ self.assertIn("# Long-term Memory — Task: Fix login bug", md)
+ self.assertIn("## Step 1-3", md)
+ self.assertIn("**Problem**: Login fails", md)
+ self.assertIn("**Conclusion**: Regex issue", md)
+ self.assertIn("## Step 4-5", md)
+ self.assertIn("**Problem**: Fix regex", md)
+ self.assertIn("**Conclusion**: Used re.escape", md)
+ self.assertIn("Generated: 2025-01-01 00:00:00 | Steps: 5", md)
+
+ def test_from_markdown_round_trip(self):
+ # to_markdown() followed by from_markdown() must preserve the task name
+ # and the single section's fields.
+ doc = MemoryDocument(
+ task_name="Fix login bug",
+ sections=[
+ MemorySection(step_range="1-3", problem="Login fails", conclusion="Regex issue"),
+ ],
+ created_at="2025-01-01 00:00:00",
+ step_count=3,
+ )
+ md = doc.to_markdown()
+ parsed = MemoryDocument.from_markdown(md)
+ self.assertEqual(parsed.task_name, "Fix login bug")
+ self.assertEqual(len(parsed.sections), 1)
+ self.assertEqual(parsed.sections[0].step_range, "1-3")
+ self.assertEqual(parsed.sections[0].problem, "Login fails")
+ self.assertEqual(parsed.sections[0].conclusion, "Regex issue")
+
+ def test_from_markdown_unknown_task(self):
+ # Markdown without the expected title yields task_name "Unknown" and no sections.
+ parsed = MemoryDocument.from_markdown("# unrelated content")
+ self.assertEqual(parsed.task_name, "Unknown")
+ self.assertEqual(len(parsed.sections), 0)
+
+ def test_memory_section_heading(self):
+ # heading() is expected to be "Step <step_range>".
+ section = MemorySection(step_range="1-5", problem="x", conclusion="y")
+ self.assertEqual(section.heading(), "Step 1-5")
+
+
+class TestMemoryTrigger(unittest.TestCase):
+ """Exercise the manual and periodic memory triggers and the create_memory_trigger factory."""
+ def test_manual_trigger_never_fires(self):
+ # The manual trigger must report False for any step count.
+ trigger = ManualMemoryTrigger()
+ step = AgentStep(step_number=1, state=AgentStepState.COMPLETED)
+ self.assertFalse(trigger.should_trigger(step, 1))
+ self.assertFalse(trigger.should_trigger(step, 100))
+ self.assertEqual(trigger.trigger_type_name(), "manual")
+
+ def test_periodic_trigger_fires_at_interval(self):
+ trigger = PeriodicMemoryTrigger(interval=5)
+ step = AgentStep(step_number=5, state=AgentStepState.COMPLETED)
+ self.assertTrue(trigger.should_trigger(step, 5))
+ # Should not fire again at same count
+ self.assertFalse(trigger.should_trigger(step, 5))
+
+ def test_periodic_trigger_does_not_fire_between_intervals(self):
+ # Step counts 3 and 7 are both below the interval of 10.
+ trigger = PeriodicMemoryTrigger(interval=10)
+ step = AgentStep(step_number=3, state=AgentStepState.COMPLETED)
+ self.assertFalse(trigger.should_trigger(step, 3))
+ self.assertFalse(trigger.should_trigger(step, 7))
+
+ def test_periodic_trigger_multiple_firings(self):
+ # Fires once at each multiple of the interval (3, then 6), and only once per count.
+ trigger = PeriodicMemoryTrigger(interval=3)
+ step = AgentStep(step_number=3, state=AgentStepState.COMPLETED)
+ self.assertTrue(trigger.should_trigger(step, 3))
+ self.assertFalse(trigger.should_trigger(step, 3))
+ self.assertTrue(trigger.should_trigger(step, 6))
+ self.assertFalse(trigger.should_trigger(step, 6))
+
+ def test_periodic_trigger_type_name(self):
+ trigger = PeriodicMemoryTrigger(interval=5)
+ self.assertEqual(trigger.trigger_type_name(), "periodic(every 5 steps)")
+
+ def test_create_memory_trigger_manual(self):
+ trigger = create_memory_trigger("manual")
+ self.assertIsInstance(trigger, ManualMemoryTrigger)
+
+ def test_create_memory_trigger_periodic(self):
+ # The interval passed to the factory must show up in the trigger's type name.
+ trigger = create_memory_trigger("periodic", periodic_interval=7)
+ self.assertIsInstance(trigger, PeriodicMemoryTrigger)
+ self.assertEqual(trigger.trigger_type_name(), "periodic(every 7 steps)")
+
+ def test_create_memory_trigger_unknown(self):
+ # An unrecognised trigger type is expected to raise ValueError.
+ with self.assertRaises(ValueError):
+ create_memory_trigger("unknown")
+
+
+class TestLongTermMemoryConfig(unittest.TestCase):
+ """Exercise LongTermMemoryConfig defaults, overrides and YAML parsing via Config.create."""
+ def test_default_values(self):
+ # Expected defaults: disabled, manual trigger, interval 10, "memory/" output, no model.
+ config = LongTermMemoryConfig()
+ self.assertFalse(config.enabled)
+ self.assertEqual(config.trigger_type, "manual")
+ self.assertEqual(config.periodic_interval, 10)
+ self.assertEqual(config.output_dir, "memory/")
+ self.assertIsNone(config.model)
+
+ def test_custom_values(self):
+ # Explicit constructor arguments must be stored unchanged.
+ config = LongTermMemoryConfig(
+ enabled=True,
+ trigger_type="periodic",
+ periodic_interval=5,
+ output_dir="my_memory/",
+ )
+ self.assertTrue(config.enabled)
+ self.assertEqual(config.trigger_type, "periodic")
+ self.assertEqual(config.periodic_interval, 5)
+ self.assertEqual(config.output_dir, "my_memory/")
+
+ def test_yaml_parsing(self):
+ # A top-level `long_term_memory` section must end up on
+ # config.trae_agent.long_term_memory with the values given in the YAML.
+ yaml_str = """
+model_providers:
+ openai:
+ api_key: test-key
+ provider: openai
+models:
+ default_model:
+ model_provider: openai
+ model: gpt-4o
+ max_tokens: 4096
+ temperature: 0.5
+ top_p: 1
+ top_k: 0
+ max_retries: 5
+ parallel_tool_calls: true
+long_term_memory:
+ enabled: true
+ trigger_type: periodic
+ periodic_interval: 5
+ output_dir: my_memory/
+agents:
+ trae_agent:
+ model: default_model
+ max_steps: 50
+ enable_lakeview: false
+ tools:
+ - bash
+"""
+ config = Config.create(config_string=yaml_str)
+ self.assertIsNotNone(config.trae_agent)
+ self.assertIsNotNone(config.trae_agent.long_term_memory)
+ ltm = config.trae_agent.long_term_memory
+ self.assertTrue(ltm.enabled)
+ self.assertEqual(ltm.trigger_type, "periodic")
+ self.assertEqual(ltm.periodic_interval, 5)
+ self.assertEqual(ltm.output_dir, "my_memory/")
+
+ def test_yaml_parsing_disabled(self):
+ # With `enabled: false` the config object must still exist, but be disabled.
+ yaml_str = """
+model_providers:
+ openai:
+ api_key: test-key
+ provider: openai
+models:
+ default_model:
+ model_provider: openai
+ model: gpt-4o
+ max_tokens: 4096
+ temperature: 0.5
+ top_p: 1
+ top_k: 0
+ max_retries: 5
+ parallel_tool_calls: true
+long_term_memory:
+ enabled: false
+agents:
+ trae_agent:
+ model: default_model
+ max_steps: 50
+ enable_lakeview: false
+ tools:
+ - bash
+"""
+ config = Config.create(config_string=yaml_str)
+ self.assertIsNotNone(config.trae_agent.long_term_memory)
+ self.assertFalse(config.trae_agent.long_term_memory.enabled)
+
+
+class TestBuildMemoryMessage(unittest.TestCase):
+ """Exercise LongTermMemory.build_memory_message with and without current-session sections."""
+ def test_build_message_with_sections(self):
+ model = ModelConfig(
+ model="gpt-4o",
+ model_provider=ModelProvider(api_key="test", provider="openai"),
+ max_tokens=4096,
+ temperature=0.5,
+ top_p=1,
+ top_k=0,
+ max_retries=5,
+ parallel_tool_calls=True,
+ )
+ ltm = LongTermMemory(
+ config=LongTermMemoryConfig(enabled=True),
+ fallback_model=model,
+ )
+ # Sections are injected through the private attribute so no LLM call is needed.
+ ltm._sections = [
+ MemorySection(step_range="1-3", problem="Login fails", conclusion="Regex issue"),
+ ]
+ msg = ltm.build_memory_message()
+ self.assertIsNotNone(msg)
+ self.assertIn("Login fails", msg.content)
+ self.assertIn("Regex issue", msg.content)
+ self.assertEqual(msg.role, "user")
+
+ def test_build_message_no_sections(self):
+ # A freshly constructed memory system has nothing to summarise, so None is expected.
+ model = ModelConfig(
+ model="gpt-4o",
+ model_provider=ModelProvider(api_key="test", provider="openai"),
+ max_tokens=4096,
+ temperature=0.5,
+ top_p=1,
+ top_k=0,
+ max_retries=5,
+ parallel_tool_calls=True,
+ )
+ ltm = LongTermMemory(
+ config=LongTermMemoryConfig(enabled=True),
+ fallback_model=model,
+ )
+ msg = ltm.build_memory_message()
+ self.assertIsNone(msg)
+
+
+class TestInjectMemoryIntoMessages(unittest.TestCase):
+ """Exercise BaseAgent.inject_memory_into_messages on a hand-assembled agent instance."""
+ def _make_agent(self):
+ # Build a minimal concrete BaseAgent subclass and instantiate it with
+ # __new__, so BaseAgent.__init__ is not run; the attributes below are
+ # then set by hand.
+ from trae_agent.agent.base_agent import BaseAgent
+
+ class ConcreteAgent(BaseAgent):
+ def new_task(self, task, extra_args=None, tool_names=None):
+ pass
+
+ async def cleanup_mcp_clients(self):
+ pass
+
+ agent = ConcreteAgent.__new__(ConcreteAgent)
+ agent._long_term_memory = None
+ agent._memory_trigger = None
+ agent._max_steps = 10
+ agent._tools = []
+ agent._cli_console = None
+ agent._trajectory_recorder = None
+ agent._tool_confirmation_config = MagicMock(enabled=False)
+ agent._tool_confirmation_approved_all = False
+ agent._allowed_command_prefixes = []
+ agent._allowed_tool_names = set()
+ return agent
+
+ def test_inject_preserves_system_and_recent(self):
+ # The memory system is mocked; only build_memory_message() is consulted.
+ agent = self._make_agent()
+ ltm = MagicMock()
+ ltm.build_memory_message.return_value = LLMMessage(
+ role="user", content="Memory summary"
+ )
+ agent._long_term_memory = ltm
+
+ messages = [LLMMessage(role="system", content="system prompt")]
+ for i in range(25):
+ messages.append(LLMMessage(role="user", content=f"msg {i}"))
+
+ result = agent.inject_memory_into_messages(messages, keep_recent=4)
+ self.assertEqual(result[0].content, "system prompt")
+ self.assertEqual(result[1].content, "Memory summary")
+ self.assertEqual(len(result), 6) # system + memory + 4 recent
+
+ def test_no_injection_without_memory(self):
+ # Without a memory system the message list must come back unchanged.
+ agent = self._make_agent()
+ messages = [LLMMessage(role="system", content="system prompt")]
+ for i in range(25):
+ messages.append(LLMMessage(role="user", content=f"msg {i}"))
+
+ result = agent.inject_memory_into_messages(messages)
+ self.assertEqual(result, messages)
+
+
+class TestExtractMemory(unittest.TestCase):
+ """Exercise LongTermMemory.extract_memory / extract_and_save with a mocked LLM client."""
+ def test_extract_memory_with_mock_llm(self):
+ model = ModelConfig(
+ model="gpt-4o",
+ model_provider=ModelProvider(api_key="test", provider="openai"),
+ max_tokens=4096,
+ temperature=0.5,
+ top_p=1,
+ top_k=0,
+ max_retries=5,
+ parallel_tool_calls=True,
+ )
+ ltm = LongTermMemory(
+ config=LongTermMemoryConfig(enabled=True),
+ fallback_model=model,
+ )
+ ltm.set_task("Fix login bug")
+
+ # Mock the LLM client
+ # NOTE(review): the canned reply is plain newline-separated text; confirm this
+ # matches the response format extract_memory() actually parses.
+ mock_response = LLMResponse(
+ content='\nLogin fails with special chars\nRegex in auth/validator.py:42 needs escaping\n',
+ model="gpt-4o",
+ )
+ ltm._llm_client = MagicMock()
+ ltm._llm_client.chat = MagicMock(return_value=mock_response)
+
+ steps = [
+ AgentStep(
+ step_number=1,
+ state=AgentStepState.COMPLETED,
+ llm_response=LLMResponse(content="Looking at login code", model="gpt-4o"),
+ ),
+ AgentStep(
+ step_number=2,
+ state=AgentStepState.COMPLETED,
+ llm_response=LLMResponse(content="Found the issue", model="gpt-4o"),
+ ),
+ ]
+
+ import asyncio
+
+ doc = asyncio.run(ltm.extract_memory(steps))
+ self.assertIsNotNone(doc)
+ self.assertEqual(len(doc.sections), 1)
+ # NOTE(review): the steps above are numbered 1 and 2, yet the expected range
+ # is "1-3" — confirm against how extract_memory() derives step_range.
+ self.assertEqual(doc.sections[0].step_range, "1-3")
+ self.assertIn("Login fails", doc.sections[0].problem)
+ self.assertIn("Regex", doc.sections[0].conclusion)
+
+ def test_extract_and_save_creates_file(self):
+ # extract_and_save() must return a path to an existing file whose content
+ # mentions the task name and the mocked problem/conclusion text.
+ import asyncio
+ import tempfile
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ model = ModelConfig(
+ model="gpt-4o",
+ model_provider=ModelProvider(api_key="test", provider="openai"),
+ max_tokens=4096,
+ temperature=0.5,
+ top_p=1,
+ top_k=0,
+ max_retries=5,
+ parallel_tool_calls=True,
+ )
+ ltm = LongTermMemory(
+ config=LongTermMemoryConfig(enabled=True, output_dir=tmpdir),
+ fallback_model=model,
+ )
+ ltm.set_task("Test task")
+
+ mock_response = LLMResponse(
+ content='\nTest problem\nTest conclusion\n',
+ model="gpt-4o",
+ )
+ ltm._llm_client = MagicMock()
+ ltm._llm_client.chat = MagicMock(return_value=mock_response)
+
+ steps = [
+ AgentStep(
+ step_number=1,
+ state=AgentStepState.COMPLETED,
+ llm_response=LLMResponse(content="Step 1", model="gpt-4o"),
+ ),
+ ]
+
+ path = asyncio.run(ltm.extract_and_save(steps))
+ self.assertIsNotNone(path)
+ import os
+
+ self.assertTrue(os.path.exists(path))
+ # NOTE(review): the file is read with the platform default encoding here.
+ with open(path) as f:
+ content = f.read()
+ self.assertIn("Test task", content)
+ self.assertIn("Test problem", content)
+ self.assertIn("Test conclusion", content)
+
+
+# NOTE(review): this guard sits in the middle of the module. When the file is run
+# as a script, unittest.main() executes at this point, before the test classes
+# defined further down exist, so they are not collected in that mode — consider
+# moving the guard to the very end of the file.
+if __name__ == "__main__":
+ unittest.main()
+
+
+class TestMemoryDocumentSessionId(unittest.TestCase):
+ """Exercise the optional session ID line in MemoryDocument Markdown output and parsing."""
+ def test_to_markdown_with_session_id(self):
+ # A document created with a session_id must render a "Session: <id>" line.
+ doc = MemoryDocument(
+ task_name="Fix login bug",
+ session_id="session_20250101_120000",
+ sections=[
+ MemorySection(step_range="1-3", problem="Login fails", conclusion="Regex issue"),
+ ],
+ created_at="2025-01-01 12:00:00",
+ step_count=3,
+ )
+ md = doc.to_markdown()
+ self.assertIn("Session: session_20250101_120000", md)
+
+ def test_to_markdown_without_session_id(self):
+ # Without a session_id no "Session:" text may appear at all.
+ doc = MemoryDocument(
+ task_name="Fix login bug",
+ sections=[
+ MemorySection(step_range="1-3", problem="Login fails", conclusion="Regex issue"),
+ ],
+ created_at="2025-01-01 12:00:00",
+ step_count=3,
+ )
+ md = doc.to_markdown()
+ self.assertNotIn("Session:", md)
+
+ def test_from_markdown_with_session_id(self):
+ # The "Session:" line must be parsed back into doc.session_id.
+ md = """# Long-term Memory — Task: Fix login bug
+Session: session_20250101_120000
+Generated: 2025-01-01 12:00:00 | Steps: 3
+
+## Step 1-3
+**Problem**: Login fails
+**Conclusion**: Regex issue
+"""
+ doc = MemoryDocument.from_markdown(md)
+ self.assertEqual(doc.session_id, "session_20250101_120000")
+
+ def test_from_markdown_without_session_id(self):
+ # A document lacking the "Session:" line parses to an empty session_id.
+ md = """# Long-term Memory — Task: Fix login bug
+Generated: 2025-01-01 12:00:00 | Steps: 3
+
+## Step 1-3
+**Problem**: Login fails
+**Conclusion**: Regex issue
+"""
+ doc = MemoryDocument.from_markdown(md)
+ self.assertEqual(doc.session_id, "")
+
+
+class TestLoadMemory(unittest.TestCase):
+    """Exercise LongTermMemory.load_memory and how preloaded sections are used.
+
+    Fixture files contain a non-ASCII em dash, so they are written explicitly as
+    UTF-8 instead of with the platform's default encoding.
+    """
+
+    # Memory document with exactly one section, shared by several tests.
+    _PREVIOUS_MEMORY_MD = (
+        "# Long-term Memory — Task: Previous task\n\n"
+        "## Step 1-2\n**Problem**: Old problem\n**Conclusion**: Old conclusion\n"
+    )
+
+    def _make_ltm(self, tmpdir: str) -> LongTermMemory:
+        """Return an enabled LongTermMemory whose output directory is *tmpdir*."""
+        model = ModelConfig(
+            model="gpt-4o",
+            model_provider=ModelProvider(api_key="test", provider="openai"),
+            max_tokens=4096,
+            temperature=0.5,
+            top_p=1,
+            top_k=0,
+            max_retries=5,
+            parallel_tool_calls=True,
+        )
+        return LongTermMemory(
+            config=LongTermMemoryConfig(enabled=True, output_dir=tmpdir),
+            fallback_model=model,
+        )
+
+    @staticmethod
+    def _write_memory_file(tmpdir: str, name: str, content: str) -> str:
+        """Write *content* as UTF-8 to ``<tmpdir>/<name>`` and return that path."""
+        md_path = f"{tmpdir}/{name}"
+        with open(md_path, "w", encoding="utf-8") as f:
+            f.write(content)
+        return md_path
+
+    def test_load_memory_populates_preloaded_sections(self):
+        import tempfile
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ltm = self._make_ltm(tmpdir)
+            # Write a memory file
+            md_path = self._write_memory_file(tmpdir, "test_memory.md", self._PREVIOUS_MEMORY_MD)
+            ltm.load_memory(md_path)
+            self.assertEqual(len(ltm._preloaded_sections), 1)
+            self.assertEqual(ltm._preloaded_sections[0].problem, "Old problem")
+
+    def test_load_memory_file_not_found(self):
+        import tempfile
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ltm = self._make_ltm(tmpdir)
+            with self.assertRaises(FileNotFoundError):
+                ltm.load_memory(f"{tmpdir}/nonexistent.md")
+
+    def test_load_memory_empty_sections(self):
+        import tempfile
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ltm = self._make_ltm(tmpdir)
+            # A document with a title but no "## Step" sections.
+            md_path = self._write_memory_file(
+                tmpdir,
+                "empty.md",
+                "# Long-term Memory — Task: Empty\n\nNo sections here.\n",
+            )
+            ltm.load_memory(md_path)
+            self.assertEqual(len(ltm._preloaded_sections), 0)
+
+    def test_preloaded_sections_survive_set_task(self):
+        import tempfile
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ltm = self._make_ltm(tmpdir)
+            md_path = self._write_memory_file(tmpdir, "test_memory.md", self._PREVIOUS_MEMORY_MD)
+            ltm.load_memory(md_path)
+            ltm._sections = [
+                MemorySection(step_range="1-1", problem="Current", conclusion="Current")
+            ]
+            # set_task() is expected to reset current-session sections only.
+            ltm.set_task("New task")
+            self.assertEqual(len(ltm._preloaded_sections), 1)
+            self.assertEqual(len(ltm._sections), 0)
+
+    def test_build_memory_message_combines_both(self):
+        import tempfile
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ltm = self._make_ltm(tmpdir)
+            md_path = self._write_memory_file(tmpdir, "test_memory.md", self._PREVIOUS_MEMORY_MD)
+            ltm.load_memory(md_path)
+            ltm._sections = [
+                MemorySection(
+                    step_range="1-1", problem="Current problem", conclusion="Current conclusion"
+                )
+            ]
+            msg = ltm.build_memory_message()
+            self.assertIsNotNone(msg)
+            self.assertIn("Context from Previous Sessions", msg.content)
+            self.assertIn("Context from Current Session", msg.content)
+            self.assertIn("Old problem", msg.content)
+            self.assertIn("Current problem", msg.content)
+
+
+class TestMemoryIndex(unittest.TestCase):
+ """Exercise the index.json bookkeeping done by LongTermMemory."""
+ def _make_ltm(self, tmpdir: str) -> LongTermMemory:
+ # Enabled memory system whose output directory is the given temp directory.
+ model = ModelConfig(
+ model="gpt-4o",
+ model_provider=ModelProvider(api_key="test", provider="openai"),
+ max_tokens=4096,
+ temperature=0.5,
+ top_p=1,
+ top_k=0,
+ max_retries=5,
+ parallel_tool_calls=True,
+ )
+ return LongTermMemory(
+ config=LongTermMemoryConfig(enabled=True, output_dir=tmpdir),
+ fallback_model=model,
+ )
+
+ def test_update_index_creates_entry(self):
+ # After one _update_index() call, index.json must contain the session
+ # keyed by its ID, carrying the task name.
+ import json
+ import tempfile
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ ltm = self._make_ltm(tmpdir)
+ ltm.set_session_id("session_test_001")
+ ltm.set_task("Test task")
+ ltm._update_index(f"{tmpdir}/memory_test.md")
+
+ index = json.loads(Path(f"{tmpdir}/index.json").read_text())
+ self.assertIn("session_test_001", index["sessions"])
+ self.assertEqual(index["sessions"]["session_test_001"]["task_name"], "Test task")
+
+ def test_update_index_appends_memory_file(self):
+ # Two updates for the same session must accumulate two memory_files entries.
+ import json
+ import tempfile
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ ltm = self._make_ltm(tmpdir)
+ ltm.set_session_id("session_test_002")
+ ltm.set_task("Test task")
+ ltm._update_index(f"{tmpdir}/memory_1.md")
+ ltm._update_index(f"{tmpdir}/memory_2.md")
+
+ index = json.loads(Path(f"{tmpdir}/index.json").read_text())
+ self.assertEqual(len(index["sessions"]["session_test_002"]["memory_files"]), 2)
+
+ def test_query_index_empty(self):
+ # A missing index file is expected to yield an empty version-1 index.
+ result = LongTermMemory.query_index("/nonexistent/path/index.json")
+ self.assertEqual(result, {"version": 1, "sessions": {}})
+
+ def test_set_trajectory_file(self):
+ # set_trajectory_file() after an index update must record the path on
+ # the session's index entry.
+ import json
+ import tempfile
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ ltm = self._make_ltm(tmpdir)
+ ltm.set_session_id("session_test_003")
+ ltm.set_task("Test task")
+ ltm._update_index(f"{tmpdir}/memory_test.md")
+ ltm.set_trajectory_file("trajectories/trajectory_test.json")
+
+ index = json.loads(Path(f"{tmpdir}/index.json").read_text())
+ self.assertEqual(
+ index["sessions"]["session_test_003"]["trajectory_file"],
+ "trajectories/trajectory_test.json",
+ )
+
+
+class TestCrossSessionIntegration(unittest.TestCase):
+ """End-to-end check: memory saved by one LongTermMemory instance is loadable by another."""
+ def test_save_and_reload_memory(self):
+ import asyncio
+ import tempfile
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # First session: extract and save
+ model = ModelConfig(
+ model="gpt-4o",
+ model_provider=ModelProvider(api_key="test", provider="openai"),
+ max_tokens=4096,
+ temperature=0.5,
+ top_p=1,
+ top_k=0,
+ max_retries=5,
+ parallel_tool_calls=True,
+ )
+ ltm1 = LongTermMemory(
+ config=LongTermMemoryConfig(enabled=True, output_dir=tmpdir),
+ fallback_model=model,
+ )
+ ltm1.set_session_id("session_first")
+ ltm1.set_task("First task")
+
+ # The LLM client is replaced by a mock returning a canned reply.
+ mock_response = LLMResponse(
+ content='\nFirst problem\nFirst conclusion\n',
+ model="gpt-4o",
+ )
+ ltm1._llm_client = MagicMock()
+ ltm1._llm_client.chat = MagicMock(return_value=mock_response)
+
+ steps = [
+ AgentStep(
+ step_number=1,
+ state=AgentStepState.COMPLETED,
+ llm_response=LLMResponse(content="Step 1", model="gpt-4o"),
+ ),
+ ]
+
+ path = asyncio.run(ltm1.extract_and_save(steps))
+ self.assertIsNotNone(path)
+
+ # Second session: load the memory
+ ltm2 = LongTermMemory(
+ config=LongTermMemoryConfig(enabled=True, output_dir=tmpdir),
+ fallback_model=model,
+ )
+ ltm2.set_session_id("session_second")
+ ltm2.set_task("Second task")
+ ltm2.load_memory(path)
+
+ self.assertEqual(len(ltm2._preloaded_sections), 1)
+ self.assertEqual(ltm2._preloaded_sections[0].problem, "First problem")
+
+ # Verify build_memory_message includes both
+ ltm2._sections = [MemorySection(step_range="1-1", problem="New problem", conclusion="New conclusion")]
+ msg = ltm2.build_memory_message()
+ self.assertIn("First problem", msg.content)
+ self.assertIn("New problem", msg.content)
diff --git a/trae_agent/agent/agent.py b/trae_agent/agent/agent.py
index bbca94f01..f9d57d243 100644
--- a/trae_agent/agent/agent.py
+++ b/trae_agent/agent/agent.py
@@ -20,6 +20,8 @@ def __init__(
cli_console: CLIConsole | None = None,
docker_config: dict | None = None,
docker_keep: bool = True,
+ session_id: str | None = None,
+ memory_path: str | None = None,
):
if isinstance(agent_type, str):
agent_type = AgentType(agent_type)
@@ -56,6 +58,23 @@ def __init__(
self.agent.set_trajectory_recorder(self.trajectory_recorder)
+ # Set up long-term memory
+ if config.trae_agent and config.trae_agent.long_term_memory and config.trae_agent.long_term_memory.enabled:
+ from trae_agent.utils.long_term_memory import LongTermMemory
+
+ ltm = LongTermMemory(
+ config=config.trae_agent.long_term_memory,
+ fallback_model=config.trae_agent.model,
+ )
+ self.agent.set_long_term_memory(ltm)
+
+ # Set session ID and preload memory if provided
+ if self.agent.long_term_memory:
+ if session_id:
+ self.agent.long_term_memory.set_session_id(session_id)
+ if memory_path:
+ self.agent.long_term_memory.load_memory(memory_path)
+
async def run(
self,
task: str,
@@ -64,6 +83,10 @@ async def run(
):
self.agent.new_task(task, extra_args, tool_names)
+ if self.agent.long_term_memory:
+ self.agent.long_term_memory.set_task(task)
+ self.agent.long_term_memory.set_trajectory_file(self.trajectory_file)
+
if self.agent.allow_mcp_servers:
if self.agent.cli_console:
self.agent.cli_console.print("Initialising MCP tools...")
diff --git a/trae_agent/agent/base_agent.py b/trae_agent/agent/base_agent.py
index 01d4fde4c..3d9d44115 100644
--- a/trae_agent/agent/base_agent.py
+++ b/trae_agent/agent/base_agent.py
@@ -18,6 +18,8 @@
from trae_agent.utils.config import AgentConfig, ModelConfig
from trae_agent.utils.llm_clients.llm_basics import LLMMessage, LLMResponse
from trae_agent.utils.llm_clients.llm_client import LLMClient
+from trae_agent.utils.long_term_memory import LongTermMemory
+from trae_agent.utils.memory_trigger import MemoryTrigger
from trae_agent.utils.trajectory_recorder import TrajectoryRecorder
@@ -77,6 +79,11 @@ def __init__(
# Trajectory recorder
self._trajectory_recorder: TrajectoryRecorder | None = None
+ # Long-term memory
+ self._long_term_memory: LongTermMemory | None = None
+ self._memory_trigger: MemoryTrigger | None = None
+ self._current_execution: AgentExecution | None = None
+
# CKG tool-specific: clear the older CKG databases
clear_older_ckg()
@@ -95,6 +102,23 @@ def set_trajectory_recorder(self, recorder: TrajectoryRecorder | None) -> None:
# Also set it on the LLM client
self._llm_client.set_trajectory_recorder(recorder)
+ @property
+ def long_term_memory(self) -> LongTermMemory | None:
+ """Get the long-term memory system for this agent.
+
+ Returns ``None`` when no memory system has been set.
+ """
+ return self._long_term_memory
+
+ def set_long_term_memory(self, ltm: LongTermMemory | None) -> None:
+ """Set the long-term memory system and its trigger.
+
+ Passing ``None`` clears both the memory system and the trigger.
+ """
+ from trae_agent.utils.memory_trigger import create_memory_trigger
+
+ self._long_term_memory = ltm
+ if ltm is not None:
+ # The trigger is created from the memory system's own configuration.
+ # NOTE(review): this reads the private `ltm._config` attribute directly.
+ self._memory_trigger = create_memory_trigger(
+ ltm._config.trigger_type, ltm._config.periodic_interval
+ )
+ else:
+ self._memory_trigger = None
+
@property
def cli_console(self) -> CLIConsole | None:
"""Get the CLI console for this agent."""
@@ -153,6 +177,7 @@ async def execute_task(self) -> AgentExecution:
start_time = time.time()
execution = AgentExecution(task=self._task, steps=[])
+ self._current_execution = execution
step: AgentStep | None = None
try:
@@ -167,6 +192,16 @@ async def execute_task(self) -> AgentExecution:
await self._finalize_step(
step, messages, execution
) # record trajectory for this step and update the CLI console
+ # Check memory trigger
+ if self._memory_trigger and self._long_term_memory and self._memory_trigger.should_trigger(step, len(execution.steps)):
+ memory_path = await self._long_term_memory.extract_and_save(
+ execution.steps
+ )
+ if memory_path and self._cli_console:
+ self._cli_console.print(
+ f"[Long-term Memory] Extracted and saved to: {memory_path}",
+ color="cyan",
+ )
if execution.agent_state == AgentState.COMPLETED:
break
step_number += 1
@@ -213,7 +248,11 @@ async def _run_llm_step(
step.state = AgentStepState.THINKING
self._update_cli_console(step, execution)
# Get LLM response
- llm_response = self._llm_client.chat(messages, self._model_config, self._tools)
+ # Optional memory-based context compression
+ effective_messages = messages
+ if self._long_term_memory and len(messages) > 20:
+ effective_messages = self.inject_memory_into_messages(messages)
+ llm_response = self._llm_client.chat(effective_messages, self._model_config, self._tools)
step.llm_response = llm_response
# Display step with LLM response
@@ -350,3 +389,38 @@ async def _tool_call_handler(
messages.append(LLMMessage(role="assistant", content=reflection))
return messages
+
+    def inject_memory_into_messages(
+        self, messages: list[LLMMessage], keep_recent: int = 4
+    ) -> list[LLMMessage]:
+        """Replace older messages with a compressed memory summary.
+
+        Keeps the first message (the system message), the memory summary
+        message, and the last `keep_recent` messages.
+
+        Args:
+            messages: Full conversation so far; the first entry is kept as-is.
+            keep_recent: Number of trailing messages to keep verbatim. Values
+                of zero or less keep none.
+
+        Returns:
+            A new list ``[first message, memory summary, *recent messages]``, or
+            `messages` itself (unchanged) when there is no memory system, no
+            memory summary is available, or `messages` is empty.
+        """
+        if not self._long_term_memory or not messages:
+            return messages
+
+        memory_msg = self._long_term_memory.build_memory_message()
+        if memory_msg is None:
+            return messages
+
+        # Index of the first "recent" message to keep. It is clamped to >= 1 so
+        # the leading message is never included twice, and an explicit start index
+        # avoids the `messages[-0:]` pitfall (which is the *whole* list).
+        first_recent = max(1, len(messages) - max(keep_recent, 0))
+        return [messages[0], memory_msg, *messages[first_recent:]]
+
+ async def extract_memory_now(self) -> str | None:
+ """Manually trigger memory extraction. Returns the path to the saved Markdown file.
+
+ Returns ``None`` without doing anything when no long-term memory system
+ is set or no execution has been recorded on this agent.
+ """
+ if not self._long_term_memory or not self._current_execution:
+ return None
+ # Extraction runs over all steps of the most recently started execution.
+ return await self._long_term_memory.extract_and_save(self._current_execution.steps)
diff --git a/trae_agent/cli.py b/trae_agent/cli.py
index 3c20b8204..b847eef42 100644
--- a/trae_agent/cli.py
+++ b/trae_agent/cli.py
@@ -50,6 +50,13 @@ def resolve_config_file(config_file: str) -> str:
return config_file
+def _generate_session_id() -> str:
+    """Generate a unique session ID based on the current timestamp.
+
+    The ID has the form ``session_YYYYmmdd_HHMMSS_<8 hex chars>``. The random
+    suffix keeps IDs distinct when several sessions start within the same
+    second (e.g. parallel runs); a one-second timestamp alone cannot
+    guarantee that, and the ID is used as a key in the memory index.
+    """
+    import uuid
+    from datetime import datetime
+
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    return f"session_{timestamp}_{uuid.uuid4().hex[:8]}"
+
+
def check_docker(timeout=3):
# 1) Check whether the docker CLI is installed
if shutil.which("docker") is None:
@@ -189,6 +196,19 @@ def cli():
help="Type of agent to use (trae_agent)",
default="trae_agent",
)
+@click.option(
+ "--extract-memory",
+ is_flag=True,
+ default=False,
+ help="Extract long-term memory after task execution",
+)
+@click.option(
+ "--memory",
+ "memory_path",
+ type=click.Path(exists=True, dir_okay=False),
+ default=None,
+ help="Path to a .md memory file to preload into the session",
+)
def run(
task: str | None,
file_path: str | None,
@@ -204,6 +224,8 @@ def run(
trajectory_file: str | None = None,
console_type: str | None = "simple",
agent_type: str | None = "trae_agent",
+ extract_memory: bool = False,
+ memory_path: str | None = None,
# --- Add Docker Mode ---
docker_image: str | None = None,
docker_container_id: str | None = None,
@@ -349,6 +371,8 @@ def run(
)
sys.exit(1)
+ session_id = _generate_session_id()
+
agent = Agent(
agent_type,
config,
@@ -356,8 +380,13 @@ def run(
cli_console,
docker_config=docker_config,
docker_keep=docker_keep,
+ session_id=session_id,
+ memory_path=memory_path,
)
+ if memory_path and agent.agent.long_term_memory:
+ console.print(f"[cyan]Preloaded memory from: {memory_path}[/cyan]")
+
if not docker_config:
try:
os.chdir(working_dir)
@@ -383,6 +412,14 @@ def run(
console.print(f"\n[green]Trajectory saved to: {agent.trajectory_file}[/green]")
+ # Extract long-term memory if requested
+ if extract_memory and agent.agent.long_term_memory:
+ memory_path = asyncio.run(agent.agent.extract_memory_now())
+ if memory_path:
+ console.print(f"[cyan]Long-term memory saved to: {memory_path}[/cyan]")
+ else:
+ console.print("[yellow]No memory could be extracted from this execution.[/yellow]")
+
except KeyboardInterrupt:
console.print("\n[yellow]Task execution interrupted by user[/yellow]")
console.print(f"[blue]Partial trajectory saved to: {agent.trajectory_file}[/blue]")
@@ -437,6 +474,13 @@ def run(
help="Type of agent to use (trae_agent)",
default="trae_agent",
)
+@click.option(
+ "--memory",
+ "memory_path",
+ type=click.Path(exists=True, dir_okay=False),
+ default=None,
+ help="Path to a .md memory file to preload into the session",
+)
def interactive(
provider: str | None = None,
model: str | None = None,
@@ -447,6 +491,7 @@ def interactive(
trajectory_file: str | None = None,
console_type: str | None = "simple",
agent_type: str | None = "trae_agent",
+ memory_path: str | None = None,
):
"""
This function starts an interactive session with Trae Agent.
@@ -492,7 +537,11 @@ def interactive(
sys.exit(1)
# Create agent
- agent = Agent(agent_type, config, trajectory_file, cli_console)
+ session_id = _generate_session_id()
+ agent = Agent(agent_type, config, trajectory_file, cli_console, session_id=session_id, memory_path=memory_path)
+
+ if memory_path and agent.agent.long_term_memory:
+ console.print(f"[cyan]Preloaded memory from: {memory_path}[/cyan]")
# Get the actual trajectory file path (in case it was auto-generated)
trajectory_file = agent.trajectory_file
@@ -722,6 +771,182 @@ def tools():
console.print(tools_table)
+def _rebuild_agent_steps(traj_data: dict) -> list:
+    """Rebuild AgentStep objects from the "agent_steps" entries of a trajectory dict.
+
+    Raises:
+        ValueError: If a step carries a "state" string that is not a valid
+            AgentStepState value.
+    """
+    from trae_agent.agent.agent_basics import AgentStep, AgentStepState
+    from trae_agent.tools.base import ToolCall, ToolResult
+    from trae_agent.utils.llm_clients.llm_basics import LLMResponse, LLMUsage
+
+    steps: list[AgentStep] = []
+    # `or []` also covers an explicit `"agent_steps": null` in the JSON.
+    for step_data in traj_data.get("agent_steps") or []:
+        llm_response = None
+        resp_data = step_data.get("llm_response")
+        if resp_data:
+            usage = None
+            usage_data = resp_data.get("usage")
+            if usage_data:
+                usage = LLMUsage(
+                    input_tokens=usage_data.get("input_tokens", 0),
+                    output_tokens=usage_data.get("output_tokens", 0),
+                )
+            tool_calls = None
+            tc_data = resp_data.get("tool_calls")
+            if tc_data:
+                tool_calls = [
+                    ToolCall(
+                        call_id=tc.get("call_id", ""),
+                        name=tc.get("name", ""),
+                        arguments=tc.get("arguments", {}),
+                    )
+                    for tc in tc_data
+                ]
+            llm_response = LLMResponse(
+                content=resp_data.get("content", ""),
+                model=resp_data.get("model", ""),
+                finish_reason=resp_data.get("finish_reason"),
+                usage=usage,
+                tool_calls=tool_calls,
+            )
+
+        tool_results = None
+        tr_data = step_data.get("tool_results")
+        if tr_data:
+            tool_results = [
+                ToolResult(
+                    call_id=tr.get("call_id", ""),
+                    name=tr.get("name", ""),
+                    success=tr.get("success", False),
+                    result=tr.get("result"),
+                    error=tr.get("error"),
+                )
+                for tr in tr_data
+            ]
+
+        steps.append(
+            AgentStep(
+                step_number=step_data.get("step_number", 0),
+                state=AgentStepState(step_data.get("state", "completed")),
+                llm_response=llm_response,
+                tool_results=tool_results,
+                reflection=step_data.get("reflection"),
+                error=step_data.get("error"),
+            )
+        )
+    return steps
+
+
+@cli.command()
+@click.option(
+    "--trajectory-file",
+    "-t",
+    required=True,
+    help="Path to a trajectory JSON file to extract memory from",
+)
+@click.option(
+    "--output-dir",
+    "-o",
+    default="memory/",
+    help="Output directory for the memory Markdown file",
+)
+@click.option(
+    "--config-file",
+    help="Path to configuration file (for model settings)",
+    default="trae_config.yaml",
+    envvar="TRAE_CONFIG_FILE",
+)
+def memory(
+    trajectory_file: str,
+    output_dir: str,
+    config_file: str,
+):
+    """Extract long-term memory from a trajectory file."""
+    # The docstring above is the command's help text, so details live here:
+    # the command loads the trajectory JSON, rebuilds its agent steps, loads the
+    # model settings from the config file, runs the extraction and prints a
+    # preview of the saved Markdown. It exits with status 1 when the trajectory
+    # is missing/unreadable/malformed, contains no steps, or no model
+    # configuration can be loaded.
+    import json
+
+    from rich.markup import escape
+
+    from trae_agent.utils.config import LongTermMemoryConfig
+    from trae_agent.utils.long_term_memory import LongTermMemory
+
+    # Load trajectory JSON
+    traj_path = Path(trajectory_file)
+    if not traj_path.exists():
+        console.print(f"[red]Error: Trajectory file not found: {traj_path}[/red]")
+        sys.exit(1)
+
+    try:
+        with open(traj_path, "r", encoding="utf-8") as f:
+            traj_data = json.load(f)
+    except (OSError, ValueError) as e:
+        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
+        console.print(
+            f"[red]Error: Could not read trajectory file {traj_path}: {escape(str(e))}[/red]"
+        )
+        sys.exit(1)
+
+    if not isinstance(traj_data, dict):
+        console.print(
+            f"[red]Error: Trajectory file {traj_path} does not contain a JSON object.[/red]"
+        )
+        sys.exit(1)
+
+    # Rebuild AgentStep objects from trajectory data
+    try:
+        steps = _rebuild_agent_steps(traj_data)
+    except ValueError as e:
+        console.print(f"[red]Error: Invalid agent step in trajectory: {escape(str(e))}[/red]")
+        sys.exit(1)
+
+    if not steps:
+        console.print("[yellow]No agent steps found in the trajectory file.[/yellow]")
+        sys.exit(1)
+
+    # Load config for model settings
+    config_file = resolve_config_file(config_file)
+    fallback_model = None
+    try:
+        config = Config.create(config_file=config_file)
+    except Exception as e:
+        # CLI boundary: report why the config could not be loaded instead of
+        # discarding the cause; the generic error below still ends the command.
+        console.print(
+            f"[red]Error: Failed to load config file {config_file}: {escape(str(e))}[/red]"
+        )
+    else:
+        fallback_model = config.trae_agent.model if config.trae_agent else None
+
+    if not fallback_model:
+        console.print("[red]Error: No model configuration found. Check your config file.[/red]")
+        sys.exit(1)
+
+    # Create LongTermMemory and extract
+    ltm_config = LongTermMemoryConfig(
+        enabled=True,
+        output_dir=output_dir,
+        model=fallback_model,
+    )
+    ltm = LongTermMemory(config=ltm_config, fallback_model=fallback_model)
+    ltm.set_task(traj_data.get("task", "Unknown"))
+
+    console.print(f"[blue]Extracting memory from {len(steps)} steps...[/blue]")
+
+    memory_path = asyncio.run(ltm.extract_and_save(steps))
+    if memory_path:
+        console.print(f"[green]Long-term memory saved to: {memory_path}[/green]")
+        # Print preview
+        md_content = Path(memory_path).read_text(encoding="utf-8")
+        from rich.markdown import Markdown as RichMarkdown
+
+        console.print(Panel(RichMarkdown(md_content), title="Memory Preview", border_style="cyan"))
+    else:
+        console.print("[yellow]No memory could be extracted from this trajectory.[/yellow]")
+
+
+@cli.command("memory-list")
+@click.option(
+ "--memory-dir",
+ default="memory/",
+ help="Directory containing memory files and index",
+)
+def memory_list(memory_dir: str):
+ """List session-memory associations from the memory index."""
+ # The docstring above doubles as the command's help text, so it is kept short.
+ from trae_agent.utils.long_term_memory import LongTermMemory
+
+ # The index is expected at <memory_dir>/index.json.
+ index_path = os.path.join(memory_dir, "index.json")
+ index = LongTermMemory.query_index(index_path)
+
+ sessions = index.get("sessions", {})
+ if not sessions:
+ console.print("[yellow]No memory index found or no sessions recorded.[/yellow]")
+ return
+
+ table = Table(title="Memory Index — Session Associations")
+ table.add_column("Session ID", style="cyan")
+ table.add_column("Task Name", style="green")
+ table.add_column("Memory Files", style="magenta")
+ table.add_column("Trajectory File", style="blue")
+ table.add_column("Created At", style="yellow")
+
+ # One row per session; a session's memory files are stacked one per line
+ # inside a single cell, and missing fields fall back to a placeholder.
+ for session_id, entry in sessions.items():
+ memory_files = "\n".join(entry.get("memory_files", []))
+ table.add_row(
+ session_id,
+ entry.get("task_name", "Unknown"),
+ memory_files,
+ entry.get("trajectory_file", ""),
+ entry.get("created_at", ""),
+ )
+
+ console.print(table)
+
+
def main():
"""Main entry point for the CLI."""
cli()
diff --git a/trae_agent/utils/cli/cli_console.py b/trae_agent/utils/cli/cli_console.py
index efcdf8171..84504fe9d 100644
--- a/trae_agent/utils/cli/cli_console.py
+++ b/trae_agent/utils/cli/cli_console.py
@@ -122,6 +122,10 @@ def set_lakeview(self, lakeview_config: LakeviewConfig | None = None):
else:
self.lake_view = None
+    def set_memory_doc(self, doc: object | None):  # noqa: B027
+        """Receive the memory document to visualize.
+
+        This base implementation ignores ``doc``; subclasses can override the
+        hook for custom display.
+        """
+
def generate_agent_step_table(agent_step: AgentStep) -> Table:
"""Log an agent step to the console."""
diff --git a/trae_agent/utils/cli/simple_console.py b/trae_agent/utils/cli/simple_console.py
index a613105b3..e8faf1cd2 100644
--- a/trae_agent/utils/cli/simple_console.py
+++ b/trae_agent/utils/cli/simple_console.py
@@ -20,6 +20,7 @@
generate_agent_step_table,
)
from trae_agent.utils.config import LakeviewConfig
+from trae_agent.utils.long_term_memory import MemoryDocument
class SimpleCLIConsole(CLIConsole):
@@ -82,6 +83,10 @@ async def start(self):
if self.lake_view and self.agent_execution:
await self._print_lakeview_summary()
+ # Print memory summary if available
+ if hasattr(self, "_memory_doc") and self._memory_doc:
+ self._print_memory_summary()
+
# Print execution summary
if self.agent_execution:
self._print_execution_summary()
@@ -107,6 +112,30 @@ def _print_step_update(
self.console.print(table)
+    def set_memory_doc(self, doc: MemoryDocument | None):
+        """Store ``doc`` so that ``_print_memory_summary`` can display it."""
+        self._memory_doc = doc
+
+    def _print_memory_summary(self):
+        """Print one panel per long-term memory section.
+
+        Does nothing when no memory document has been set. Section text is
+        free-form (parsed from an LLM reply or a Markdown file), so it is
+        escaped before being embedded in Rich markup; otherwise bracketed
+        content such as ``[/tmp/x]`` would be parsed as a style tag and could
+        raise ``MarkupError``.
+        """
+        from rich.markup import escape
+
+        doc = getattr(self, "_memory_doc", None)
+        if not doc:
+            return
+
+        rule = "=" * 60
+        self.console.print("\n" + rule)
+        self.console.print("[bold cyan]Long-term Memory[/bold cyan]")
+        self.console.print(rule)
+
+        for section in doc.sections:
+            body = (
+                f"[bold]Problem:[/bold] {escape(section.problem)}\n"
+                f"[bold]Conclusion:[/bold] {escape(section.conclusion)}"
+            )
+            self.console.print(
+                Panel(
+                    body,
+                    # String titles are rendered as markup too, so escape them as well.
+                    title=escape(section.heading()),
+                    border_style="cyan",
+                    width=80,
+                )
+            )
+
async def _print_lakeview_summary(self):
"""Print lakeview summary of all completed steps."""
self.console.print("\n" + "=" * 60)
diff --git a/trae_agent/utils/config.py b/trae_agent/utils/config.py
index d95026901..951e5ba8d 100644
--- a/trae_agent/utils/config.py
+++ b/trae_agent/utils/config.py
@@ -155,6 +155,17 @@ class AgentConfig:
tools: list[str]
+@dataclass
+class LongTermMemoryConfig:
+    """Settings for the long-term memory feature.
+
+    Attributes:
+        enabled: Whether the feature is switched on; off by default.
+        trigger_type: Either "manual" or "periodic".
+        periodic_interval: Number of steps between automatic extractions.
+        output_dir: Directory that receives the Markdown memory files.
+        model: Optional dedicated model; the agent model is the fallback
+            when this is unset.
+    """
+
+    enabled: bool = False
+    trigger_type: str = "manual"
+    periodic_interval: int = 10
+    output_dir: str = "memory/"
+    model: ModelConfig | None = None
+
+
@dataclass
class TraeAgentConfig(AgentConfig):
"""
@@ -162,6 +173,7 @@ class TraeAgentConfig(AgentConfig):
"""
enable_lakeview: bool = True
+ long_term_memory: LongTermMemoryConfig | None = None
tools: list[str] = field(
default_factory=lambda: [
"bash",
@@ -271,6 +283,20 @@ def create(
}
allow_mcp_servers = yaml_config.get("allow_mcp_servers", [])
+ # Parse long_term_memory config
+ ltm_yaml = yaml_config.get("long_term_memory", None)
+ ltm_config: LongTermMemoryConfig | None = None
+ if ltm_yaml is not None:
+ ltm_model_name = ltm_yaml.get("model", None)
+ ltm_model = config_models.get(ltm_model_name) if ltm_model_name else None
+ ltm_config = LongTermMemoryConfig(
+ enabled=ltm_yaml.get("enabled", False),
+ trigger_type=ltm_yaml.get("trigger_type", "manual"),
+ periodic_interval=ltm_yaml.get("periodic_interval", 10),
+ output_dir=ltm_yaml.get("output_dir", "memory/"),
+ model=ltm_model,
+ )
+
# Parse agents
agents = yaml_config.get("agents", None)
if agents is not None and len(agents.keys()) > 0:
@@ -286,6 +312,7 @@ def create(
case "trae_agent":
trae_agent_config = TraeAgentConfig(
**agent_config,
+ long_term_memory=ltm_config,
mcp_servers_config=mcp_servers_config,
allow_mcp_servers=allow_mcp_servers,
)
diff --git a/trae_agent/utils/long_term_memory.py b/trae_agent/utils/long_term_memory.py
new file mode 100644
index 000000000..5f29cf3ff
--- /dev/null
+++ b/trae_agent/utils/long_term_memory.py
@@ -0,0 +1,330 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+# SPDX-License-Identifier: MIT
+
+"""Long-term memory extraction and visualization for trae-agent."""
+
+import json
+import re
+from dataclasses import dataclass, field
+from datetime import datetime
+from pathlib import Path
+
+from trae_agent.agent.agent_basics import AgentStep
+from trae_agent.utils.config import LongTermMemoryConfig
+from trae_agent.utils.llm_clients.llm_basics import LLMMessage
+from trae_agent.utils.llm_clients.llm_client import LLMClient
+
+MEMORY_EXTRACTOR_PROMPT = """
+Given the following execution steps of an AI agent solving a task, extract the key "problem" and "conclusion" for each logical group of steps.
+
+A "problem" describes what the agent was trying to solve or figure out.
+A "conclusion" describes what the agent discovered, decided, or accomplished.
+
+Group consecutive steps that address the same sub-problem into a single section.
+Number the sections sequentially using the step range (e.g., steps="1-3", steps="4-5").
+
+Output format — repeat this block for each group:
+<group steps="1-3">
+<problem>...</problem>
+<conclusion>...</conclusion>
+</group>
+
+Be concise. Each problem and conclusion should be at most 2 sentences.
+Focus on facts: file paths, function names, root causes, specific changes made.
+Do not include any other commentary.
+"""
+
+# NOTE(review): the angle-bracket markup in the prompt template above and in
+# this pattern was missing, which left only two capture groups although
+# extract_memory() reads groups 1-3 (step range, problem, conclusion). The
+# tags are reconstructed here; confirm the exact tag names against upstream.
+group_re = re.compile(
+    r'<group steps="(.*?)">\s*'
+    r"<problem>(.*?)</problem>\s*"
+    r"<conclusion>(.*?)</conclusion>\s*"
+    r"</group>",
+    re.DOTALL,
+)
+
+
+@dataclass
+class MemorySection:
+    """One extracted memory entry covering a range of steps.
+
+    Attributes:
+        step_range: Step span the entry covers, e.g. "1-3".
+        problem: What the agent was trying to solve.
+        conclusion: What the agent found out or accomplished.
+    """
+
+    step_range: str
+    problem: str
+    conclusion: str
+
+    def heading(self) -> str:
+        """Return the section title, e.g. ``Step 1-3``."""
+        return "Step {}".format(self.step_range)
+
+
+@dataclass
+class MemoryDocument:
+    """A complete memory document extracted from agent execution.
+
+    Attributes:
+        task_name: Name of the task the memory belongs to.
+        session_id: Session identifier; omitted from the Markdown when empty.
+        sections: Extracted problem/conclusion sections, in order.
+        created_at: Creation timestamp text; the "Generated" line is omitted
+            from the Markdown when empty.
+        step_count: Number of steps reported on the "Generated" line.
+    """
+
+    task_name: str
+    session_id: str = ""
+    sections: list[MemorySection] = field(default_factory=list)
+    created_at: str = ""
+    step_count: int = 0
+
+    def to_markdown(self) -> str:
+        """Render the full document as Markdown."""
+        lines = [f"# Long-term Memory — Task: {self.task_name}", ""]
+        if self.session_id:
+            lines += [f"Session: {self.session_id}", ""]
+        if self.created_at:
+            lines += [f"Generated: {self.created_at} | Steps: {self.step_count}", ""]
+        for section in self.sections:
+            lines += [
+                f"## {section.heading()}",
+                f"**Problem**: {section.problem}",
+                f"**Conclusion**: {section.conclusion}",
+                "",
+            ]
+        return "\n".join(lines)
+
+    @classmethod
+    def from_markdown(cls, markdown_text: str) -> "MemoryDocument":
+        """Parse Markdown produced by ``to_markdown`` back into a document.
+
+        Missing parts fall back to defaults: task name "Unknown", empty
+        session ID and timestamp, step count 0, no sections. Problem and
+        conclusion are each read from a single line.
+        """
+        task_match = re.search(r"# Long-term Memory — Task: (.+)", markdown_text)
+        session_match = re.search(r"^Session: (.+)$", markdown_text, re.MULTILINE)
+        # The "Generated" line carries created_at and step_count; reading it
+        # back keeps a to_markdown() -> from_markdown() round trip lossless.
+        generated_match = re.search(
+            r"^Generated: (.+?) \| Steps: (\d+)\s*$", markdown_text, re.MULTILINE
+        )
+
+        # Match "## Step N-M" blocks followed by their two labelled lines.
+        section_pattern = re.compile(
+            r"## Step ([^\n]+)\n\*\*Problem\*\*: ([^\n]+)\n\*\*Conclusion\*\*: ([^\n]+)"
+        )
+        sections = [
+            MemorySection(
+                step_range=match.group(1).strip(),
+                problem=match.group(2).strip(),
+                conclusion=match.group(3).strip(),
+            )
+            for match in section_pattern.finditer(markdown_text)
+        ]
+
+        return cls(
+            task_name=task_match.group(1).strip() if task_match else "Unknown",
+            session_id=session_match.group(1).strip() if session_match else "",
+            sections=sections,
+            created_at=generated_match.group(1).strip() if generated_match else "",
+            step_count=int(generated_match.group(2)) if generated_match else 0,
+        )
+
+
+class LongTermMemory:
+ """Long-term memory extraction and management system."""
+
+    def __init__(self, config: LongTermMemoryConfig, fallback_model):
+        """Create the LLM client and make sure the output directory exists.
+
+        Args:
+            config: Long-term memory settings; ``config.model`` is used for
+                extraction when it is set.
+            fallback_model: Model configuration used when ``config.model`` is
+                not set.
+        """
+        from trae_agent.utils.config import ModelConfig
+
+        self._config = config
+        self._task_name: str = ""
+        self._session_id: str = ""
+        self._sections: list[MemorySection] = []
+        self._preloaded_sections: list[MemorySection] = []
+
+        chosen_model = config.model or fallback_model
+        self._model_config: ModelConfig = chosen_model
+        self._llm_client = LLMClient(chosen_model)
+
+        self._output_dir = Path(config.output_dir)
+        self._output_dir.mkdir(parents=True, exist_ok=True)
+
+    def set_task(self, task_name: str):
+        """Begin a new task: remember its name and clear current-session sections.
+
+        Preloaded sections are not touched.
+        """
+        self._task_name, self._sections = task_name, []
+
+    def set_session_id(self, session_id: str) -> None:
+        """Remember the session ID used for saved documents and index entries."""
+        self._session_id = session_id
+
+    def load_memory(self, path: str) -> None:
+        """Preload sections from a Markdown memory file saved earlier.
+
+        Parsed sections replace ``_preloaded_sections``, which ``set_task()``
+        does not reset, so they carry over as cross-session context. A file
+        that yields no sections leaves the current preloaded sections as-is.
+
+        Raises:
+            FileNotFoundError: If ``path`` does not exist.
+        """
+        source = Path(path)
+        if not source.exists():
+            raise FileNotFoundError(f"Memory file not found: {path}")
+
+        parsed = MemoryDocument.from_markdown(source.read_text(encoding="utf-8"))
+        if parsed.sections:
+            self._preloaded_sections = parsed.sections
+
+    def _agent_step_str(self, agent_step: AgentStep) -> str | None:
+        """Render one agent step as plain text for the extraction prompt.
+
+        The text is the LLM response content, followed by a "Tool calls"
+        section and a "Tool results" section when the step has any.
+
+        Returns:
+            The rendered text, or None when the step has no LLM response.
+        """
+        response = agent_step.llm_response
+        if response is None:
+            return None
+
+        # Presumably content can be missing on tool-call-only replies; treat
+        # that as empty text instead of failing on ``None.strip()``.
+        parts = [(response.content or "").strip()]
+
+        # Truthiness also skips an empty list, which would otherwise leave a
+        # dangling "Tool calls:" header with nothing under it.
+        if response.tool_calls:
+            calls = "\n".join(
+                f"[`{call.name}`] `{call.arguments}`" for call in response.tool_calls
+            )
+            parts.append(f"Tool calls:\n{calls.strip()}")
+
+        if agent_step.tool_results:
+            results = "\n".join(
+                f"[{res.name}] success={res.success}: {res.result or res.error or ''}"
+                for res in agent_step.tool_results
+                if res
+            )
+            if results:
+                parts.append(f"Tool results:\n{results}")
+
+        return "\n\n".join(parts)
+
+    async def extract_memory(self, steps: list[AgentStep]) -> MemoryDocument | None:
+        """Summarize agent steps into problem/conclusion sections using the LLM.
+
+        Steps that render to nothing (e.g. no LLM response) are skipped. The
+        rendered steps are sent together with ``MEMORY_EXTRACTOR_PROMPT``; the
+        reply is parsed with ``group_re`` and the request is repeated, up to
+        10 extra times, while the reply contains no parsable group. On success
+        the parsed sections also replace ``self._sections``.
+
+        Args:
+            steps: Agent steps to summarize.
+
+        Returns:
+            The extracted document, or None when no step could be rendered or
+            no section could be parsed from the reply.
+        """
+        import copy
+
+        max_chars = 300_000
+        max_retries = 10
+
+        # Label every step so the model can report the step ranges the prompt
+        # asks for. NOTE(review): the wrapper markup here looked stripped and
+        # is reconstructed; presumably AgentStep exposes ``step_number`` —
+        # the 1-based position is used when it does not. Confirm upstream.
+        step_texts: list[str] = []
+        for position, step in enumerate(steps, start=1):
+            step_str = self._agent_step_str(step)
+            if step_str:
+                step_id = getattr(step, "step_number", position)
+                step_texts.append(f'<step id="{step_id}">\n{step_str}\n</step>')
+
+        if not step_texts:
+            return None
+
+        # When over the limit, keep the tail, i.e. the most recent steps.
+        steps_formatted = "\n\n".join(step_texts)[-max_chars:]
+
+        llm_messages = [
+            LLMMessage(
+                role="user",
+                content=f"Below are the execution steps of an AI agent:\n\n{steps_formatted}",
+            ),
+            LLMMessage(role="assistant", content="I understand."),
+            LLMMessage(role="user", content=MEMORY_EXTRACTOR_PROMPT),
+        ]
+
+        # Use a private copy for the low-temperature request: the stored
+        # config object may be the caller's fallback model, and changing it
+        # in place would alter that model's temperature as well.
+        model_config = copy.copy(self._model_config)
+        model_config.temperature = 0.1
+
+        content = ""
+        for _attempt in range(1 + max_retries):
+            llm_response = self._llm_client.chat(
+                model_config=model_config,
+                messages=llm_messages,
+                reuse_history=False,
+            )
+            content = llm_response.content.strip()
+            if group_re.search(content):
+                break
+
+        sections = [
+            MemorySection(
+                step_range=match.group(1).strip(),
+                problem=match.group(2).strip(),
+                conclusion=match.group(3).strip(),
+            )
+            for match in group_re.finditer(content)
+        ]
+        if not sections:
+            return None
+
+        self._sections = sections
+
+        return MemoryDocument(
+            task_name=self._task_name,
+            sections=sections,
+            created_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+            step_count=len(steps),
+        )
+
+    async def extract_and_save(self, steps: list[AgentStep]) -> str | None:
+        """Extract memory, write it to a timestamped Markdown file and index it.
+
+        Returns:
+            The written file's path as a string, or None when nothing could
+            be extracted.
+        """
+        doc = await self.extract_memory(steps)
+        if doc is None:
+            return None
+
+        self._sections = doc.sections
+        doc.session_id = self._session_id
+
+        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        target = self._output_dir / f"memory_{stamp}.md"
+        target.write_text(doc.to_markdown(), encoding="utf-8")
+
+        saved_path = str(target)
+        self._update_index(saved_path)
+        return saved_path
+
+    def build_memory_message(self) -> LLMMessage | None:
+        """Build a user message summarizing preloaded and current sections.
+
+        Preloaded sections are listed first under their own heading, followed
+        by the current session's sections.
+
+        Returns:
+            The message, or None when there are no sections at all.
+        """
+        if not (self._preloaded_sections or self._sections):
+            return None
+
+        chunks = [
+            "# Long-term Memory Summary\n\n",
+            "The following is a compressed summary of the agent's previous execution. ",
+            "Use this as context instead of the full conversation history.\n\n",
+        ]
+        groups = (
+            ("## Context from Previous Sessions\n\n", self._preloaded_sections),
+            ("## Context from Current Session\n\n", self._sections),
+        )
+        for title, group in groups:
+            if not group:
+                continue
+            chunks.append(title)
+            chunks.extend(
+                f"### {item.heading()}\n"
+                f"**Problem**: {item.problem}\n"
+                f"**Conclusion**: {item.conclusion}\n\n"
+                for item in group
+            )
+
+        return LLMMessage(role="user", content="".join(chunks))
+
+ # --- Memory index management ---
+
+    def _index_path(self) -> Path:
+        """Location of ``index.json`` inside the output directory."""
+        return self._output_dir.joinpath("index.json")
+
+    def _load_index(self) -> dict:
+        """Load the memory index from disk, or return an empty structure.
+
+        A loaded index that lacks the "sessions" key gets an empty mapping
+        for it, so callers can use ``index["sessions"]`` without a KeyError.
+
+        Raises:
+            json.JSONDecodeError: If the index file is not valid JSON.
+            ValueError: If the index or its "sessions" value is not a JSON
+                object.
+        """
+        path = self._index_path()
+        if not path.exists():
+            return {"version": 1, "sessions": {}}
+
+        index = json.loads(path.read_text(encoding="utf-8"))
+        if not isinstance(index, dict):
+            raise ValueError(f"Memory index must be a JSON object: {path}")
+        if not isinstance(index.setdefault("sessions", {}), dict):
+            raise ValueError(f"Memory index 'sessions' must be a JSON object: {path}")
+        return index
+
+    def _save_index(self, index: dict) -> None:
+        """Write ``index`` to the index file as indented UTF-8 JSON."""
+        target = self._index_path()
+        serialized = json.dumps(index, indent=2, ensure_ascii=False)
+        target.write_text(serialized, encoding="utf-8")
+
+    def _update_index(self, memory_file_path: str) -> None:
+        """Register ``memory_file_path`` under the current session in the index.
+
+        Does nothing without a session ID. The first file of a session
+        creates its entry; later files are appended unless already listed.
+        """
+        session_id = self._session_id
+        if not session_id:
+            return
+
+        index = self._load_index()
+        sessions = index["sessions"]
+        if session_id in sessions:
+            known_files = sessions[session_id]["memory_files"]
+            if memory_file_path not in known_files:
+                known_files.append(memory_file_path)
+        else:
+            sessions[session_id] = {
+                "task_name": self._task_name,
+                "memory_files": [memory_file_path],
+                "trajectory_file": "",
+                "created_at": datetime.now().isoformat(),
+            }
+        self._save_index(index)
+
+    def set_trajectory_file(self, trajectory_file: str) -> None:
+        """Record the trajectory file path for this session in the index.
+
+        Does nothing without a session ID or when the session has no index
+        entry yet; the index file is only written when an entry was updated.
+        """
+        if not self._session_id:
+            return
+
+        index = self._load_index()
+        sessions = index["sessions"]
+        if self._session_id not in sessions:
+            # Nothing changed, so do not rewrite (or newly create) the index.
+            return
+
+        sessions[self._session_id]["trajectory_file"] = trajectory_file
+        self._save_index(index)
+
+    @classmethod
+    def query_index(cls, index_path: str) -> dict:
+        """Return the parsed index at ``index_path``.
+
+        An empty index structure is returned when the file does not exist.
+        Used by CLI commands.
+        """
+        index_file = Path(index_path)
+        if index_file.exists():
+            return json.loads(index_file.read_text(encoding="utf-8"))
+        return {"version": 1, "sessions": {}}
diff --git a/trae_agent/utils/memory_trigger.py b/trae_agent/utils/memory_trigger.py
new file mode 100644
index 000000000..645ca0ae8
--- /dev/null
+++ b/trae_agent/utils/memory_trigger.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+# SPDX-License-Identifier: MIT
+
+"""Memory trigger factory for long-term memory extraction."""
+
+from abc import ABC, abstractmethod
+
+from trae_agent.agent.agent_basics import AgentStep
+
+
+class MemoryTrigger(ABC):
+    """Interface for policies that decide when memory extraction runs."""
+
+    @abstractmethod
+    def should_trigger(self, step: AgentStep, steps_completed: int) -> bool:
+        """Decide whether memory extraction should run now.
+
+        Args:
+            step: The step that just completed.
+            steps_completed: Total number of completed steps so far.
+
+        Returns:
+            True if memory extraction should be triggered now.
+        """
+
+    @abstractmethod
+    def trigger_type_name(self) -> str:
+        """Return a human-readable name for this trigger type."""
+
+
+class ManualMemoryTrigger(MemoryTrigger):
+    """Trigger that never fires by itself; extraction must be requested explicitly."""
+
+    def should_trigger(self, step: AgentStep, steps_completed: int) -> bool:
+        """Return False for every step."""
+        return False
+
+    def trigger_type_name(self) -> str:
+        """Return the name ``manual``."""
+        return "manual"
+
+
+class PeriodicMemoryTrigger(MemoryTrigger):
+    """Triggers every N completed steps."""
+
+    def __init__(self, interval: int = 10):
+        """Create a trigger that fires every ``interval`` steps.
+
+        Raises:
+            ValueError: If ``interval`` is smaller than 1. A zero interval
+                would otherwise fail later with ZeroDivisionError in
+                ``should_trigger``.
+        """
+        if interval < 1:
+            raise ValueError(f"periodic interval must be a positive integer, got {interval!r}")
+        self._interval = interval
+        self._last_triggered_at: int = 0
+
+    def should_trigger(self, step: AgentStep, steps_completed: int) -> bool:
+        """Return True once for each positive multiple of the interval.
+
+        The step count that fired last is remembered, so asking again with
+        the same ``steps_completed`` returns False.
+        """
+        due = steps_completed > 0 and steps_completed % self._interval == 0
+        if due and steps_completed != self._last_triggered_at:
+            self._last_triggered_at = steps_completed
+            return True
+        return False
+
+    def trigger_type_name(self) -> str:
+        """Return a name that includes the configured interval."""
+        return f"periodic(every {self._interval} steps)"
+
+
+def create_memory_trigger(trigger_type: str, periodic_interval: int = 10) -> MemoryTrigger:
+    """Build the trigger selected by ``trigger_type``.
+
+    Args:
+        trigger_type: "manual" or "periodic".
+        periodic_interval: Interval handed to the periodic trigger; unused
+            for "manual".
+
+    Raises:
+        ValueError: If ``trigger_type`` is neither "manual" nor "periodic".
+    """
+    if trigger_type == "manual":
+        return ManualMemoryTrigger()
+    if trigger_type == "periodic":
+        return PeriodicMemoryTrigger(interval=periodic_interval)
+    raise ValueError(f"Unknown memory trigger_type: {trigger_type}")
diff --git a/trae_config.yaml.example b/trae_config.yaml.example
index 21f9d4d20..a471ea695 100644
--- a/trae_config.yaml.example
+++ b/trae_config.yaml.example
@@ -18,6 +18,13 @@ mcp_servers:
lakeview:
model: lakeview_model
+# long_term_memory:
+# enabled: true
+# trigger_type: periodic # "manual" or "periodic"
+# periodic_interval: 10 # steps between auto-extractions (for periodic trigger)
+# output_dir: memory/ # directory for Markdown memory files
+# # model: ltm_model # optional separate model; falls back to agent model
+
model_providers:
anthropic:
api_key: your_anthropic_api_key