Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
659 changes: 659 additions & 0 deletions tests/utils/test_long_term_memory.py

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions trae_agent/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def __init__(
cli_console: CLIConsole | None = None,
docker_config: dict | None = None,
docker_keep: bool = True,
session_id: str | None = None,
memory_path: str | None = None,
):
if isinstance(agent_type, str):
agent_type = AgentType(agent_type)
Expand Down Expand Up @@ -56,6 +58,23 @@ def __init__(

self.agent.set_trajectory_recorder(self.trajectory_recorder)

# Set up long-term memory
if config.trae_agent and config.trae_agent.long_term_memory and config.trae_agent.long_term_memory.enabled:
from trae_agent.utils.long_term_memory import LongTermMemory

ltm = LongTermMemory(
config=config.trae_agent.long_term_memory,
fallback_model=config.trae_agent.model,
)
self.agent.set_long_term_memory(ltm)

# Set session ID and preload memory if provided
if self.agent.long_term_memory:
if session_id:
self.agent.long_term_memory.set_session_id(session_id)
if memory_path:
self.agent.long_term_memory.load_memory(memory_path)

async def run(
self,
task: str,
Expand All @@ -64,6 +83,10 @@ async def run(
):
self.agent.new_task(task, extra_args, tool_names)

if self.agent.long_term_memory:
self.agent.long_term_memory.set_task(task)
self.agent.long_term_memory.set_trajectory_file(self.trajectory_file)

if self.agent.allow_mcp_servers:
if self.agent.cli_console:
self.agent.cli_console.print("Initialising MCP tools...")
Expand Down
76 changes: 75 additions & 1 deletion trae_agent/agent/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from trae_agent.utils.config import AgentConfig, ModelConfig
from trae_agent.utils.llm_clients.llm_basics import LLMMessage, LLMResponse
from trae_agent.utils.llm_clients.llm_client import LLMClient
from trae_agent.utils.long_term_memory import LongTermMemory
from trae_agent.utils.memory_trigger import MemoryTrigger
from trae_agent.utils.trajectory_recorder import TrajectoryRecorder


Expand Down Expand Up @@ -77,6 +79,11 @@ def __init__(
# Trajectory recorder
self._trajectory_recorder: TrajectoryRecorder | None = None

# Long-term memory
self._long_term_memory: LongTermMemory | None = None
self._memory_trigger: MemoryTrigger | None = None
self._current_execution: AgentExecution | None = None

# CKG tool-specific: clear the older CKG databases
clear_older_ckg()

Expand All @@ -95,6 +102,23 @@ def set_trajectory_recorder(self, recorder: TrajectoryRecorder | None) -> None:
# Also set it on the LLM client
self._llm_client.set_trajectory_recorder(recorder)

@property
def long_term_memory(self) -> LongTermMemory | None:
"""Get the long-term memory system for this agent."""
return self._long_term_memory

def set_long_term_memory(self, ltm: LongTermMemory | None) -> None:
"""Set the long-term memory system and its trigger."""
from trae_agent.utils.memory_trigger import create_memory_trigger

self._long_term_memory = ltm
if ltm is not None:
self._memory_trigger = create_memory_trigger(
ltm._config.trigger_type, ltm._config.periodic_interval
)
else:
self._memory_trigger = None

@property
def cli_console(self) -> CLIConsole | None:
"""Get the CLI console for this agent."""
Expand Down Expand Up @@ -153,6 +177,7 @@ async def execute_task(self) -> AgentExecution:

start_time = time.time()
execution = AgentExecution(task=self._task, steps=[])
self._current_execution = execution
step: AgentStep | None = None

try:
Expand All @@ -167,6 +192,16 @@ async def execute_task(self) -> AgentExecution:
await self._finalize_step(
step, messages, execution
) # record trajectory for this step and update the CLI console
# Check memory trigger
if self._memory_trigger and self._long_term_memory and self._memory_trigger.should_trigger(step, len(execution.steps)):
memory_path = await self._long_term_memory.extract_and_save(
execution.steps
)
if memory_path and self._cli_console:
self._cli_console.print(
f"[Long-term Memory] Extracted and saved to: {memory_path}",
color="cyan",
)
if execution.agent_state == AgentState.COMPLETED:
break
step_number += 1
Expand Down Expand Up @@ -213,7 +248,11 @@ async def _run_llm_step(
step.state = AgentStepState.THINKING
self._update_cli_console(step, execution)
# Get LLM response
llm_response = self._llm_client.chat(messages, self._model_config, self._tools)
# Optional memory-based context compression
effective_messages = messages
if self._long_term_memory and len(messages) > 20:
effective_messages = self.inject_memory_into_messages(messages)
llm_response = self._llm_client.chat(effective_messages, self._model_config, self._tools)
step.llm_response = llm_response

# Display step with LLM response
Expand Down Expand Up @@ -350,3 +389,38 @@ async def _tool_call_handler(
messages.append(LLMMessage(role="assistant", content=reflection))

return messages

def inject_memory_into_messages(
self, messages: list[LLMMessage], keep_recent: int = 4
) -> list[LLMMessage]:
"""Replace older messages with a compressed memory summary.

Keeps the system message, the memory summary message, and the last
`keep_recent` messages.
"""
if not self._long_term_memory:
return messages

memory_msg = self._long_term_memory.build_memory_message()
if memory_msg is None:
return messages

# Always keep system message (index 0)
result = [messages[0]]

# Add memory summary
result.append(memory_msg)

# Keep the most recent messages
if len(messages) > keep_recent:
result.extend(messages[-keep_recent:])
else:
result.extend(messages[1:])

return result

async def extract_memory_now(self) -> str | None:
"""Manually trigger memory extraction. Returns the path to the saved Markdown file."""
if not self._long_term_memory or not self._current_execution:
return None
return await self._long_term_memory.extract_and_save(self._current_execution.steps)
Loading