Skip to content

Commit ee7836d

Browse files
authored
Merge pull request #292 from embersax/componentmemory
[MRG]add memory to each coponent
2 parents cf09393 + 574c689 commit ee7836d

File tree

9 files changed

+507
-13
lines changed

9 files changed

+507
-13
lines changed

mle/agents/advisor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from mle.function import *
66
from mle.utils import get_config, print_in_box, clean_json_string
7+
from mle.utils.component_memory import trace_component
78

89

910
def process_report(requirement: str, suggestions: dict):
@@ -136,6 +137,7 @@ def __init__(self, model, console=None, mode='normal'):
136137
self.sys_prompt += self.json_mode_prompt
137138
self.chat_history.append({"role": 'system', "content": self.sys_prompt})
138139

140+
@trace_component("advisor")
139141
def suggest(self, requirement, return_raw=False):
140142
"""
141143
Handle the query from the model query response.
@@ -163,6 +165,7 @@ def suggest(self, requirement, return_raw=False):
163165

164166
return process_report(requirement, suggestions)
165167

168+
@trace_component("advisor")
166169
def interact(self, requirement):
167170
"""
168171
Interact with the user to ask and suggest.

mle/agents/chat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from mle.function import *
44
from mle.utils import get_config, WorkflowCache
5-
5+
from mle.utils.component_memory import trace_component
66

77
class ChatAgent:
88

@@ -107,7 +107,7 @@ def greet(self):
107107

108108
self.chat_history.append({"role": "assistant", "content": greets})
109109
return greets
110-
110+
@trace_component("chat")
111111
def chat(self, user_prompt):
112112
"""
113113
Handle the response from the model streaming.

mle/agents/coder.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from mle.function import *
66
from mle.utils import get_config, print_in_box, clean_json_string
7-
7+
from mle.utils.component_memory import trace_component
88

99
def process_summary(summary_dict: dict):
1010
"""
@@ -148,7 +148,8 @@ def __init__(self, model, working_dir='.', console=None, single_file=False):
148148

149149
self.sys_prompt += self.json_mode_prompt
150150
self.chat_history.append({"role": 'system', "content": self.sys_prompt})
151-
151+
152+
@trace_component("coder")
152153
def read_requirement(self, advisor_report: str):
153154
"""
154155
Read the user requirement and the advisor report.
@@ -157,6 +158,7 @@ def read_requirement(self, advisor_report: str):
157158
"""
158159
self.chat_history.append({"role": "system", "content": advisor_report})
159160

161+
@trace_component("coder")
160162
def code(self, task_dict: dict):
161163
"""
162164
Handle the query from the model query response.
@@ -181,7 +183,8 @@ def code(self, task_dict: dict):
181183
code_summary = clean_json_string(text)
182184
code_summary.update({'task': task_dict.get('task'), 'task_description': task_dict.get('description')})
183185
return code_summary
184-
186+
187+
@trace_component("coder")
185188
def debug(self, task_dict: dict, debug_report: dict):
186189
"""
187190
Handle the query from the model query response.

mle/agents/debugger.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from mle.utils import get_config, print_in_box
55

66
from rich.console import Console
7-
7+
from mle.utils.component_memory import trace_component
88

99
def process_debug_report(debug_report):
1010
"""
@@ -143,7 +143,7 @@ def analyze_with_log(self, commands, logs):
143143
report_dict = json.loads(text)
144144
print_in_box(process_debug_report(report_dict), self.console, title="MLE Debugger", color="yellow")
145145
return report_dict
146-
146+
@trace_component("debugger")
147147
def analyze(self, code_report):
148148
"""
149149
Handle the query from the model query response.
@@ -177,4 +177,4 @@ def analyze(self, code_report):
177177
self.chat_history.append({"role": "assistant", "content": text})
178178
report_dict = json.loads(text)
179179
print_in_box(process_debug_report(report_dict), self.console, title="MLE Debugger", color="yellow")
180-
return report_dict
180+
return report_dict

mle/agents/planner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from rich.console import Console
55

66
from mle.utils import print_in_box, clean_json_string
7+
from mle.utils.component_memory import trace_component
78

89

910
def process_plan(plan_dict: dict):
@@ -82,6 +83,7 @@ def __init__(self, model, console=None):
8283
self.sys_prompt += self.json_mode_prompt
8384
self.chat_history.append({"role": 'system', "content": self.sys_prompt})
8485

86+
@trace_component("planner")
8587
def plan(self, user_prompt):
8688
"""
8789
Handle the query from the model query response.
@@ -102,6 +104,7 @@ def plan(self, user_prompt):
102104
except json.JSONDecodeError as e:
103105
return clean_json_string(text)
104106

107+
@trace_component("planner")
105108
def interact(self, user_prompt):
106109
"""
107110
Handle the query from the model query response.

mle/agents/reporter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from rich.console import Console
33
from time import gmtime, strftime
4-
4+
from mle.utils.component_memory import trace_component
55

66
class ReportAgent:
77

@@ -138,7 +138,8 @@ def process_knowledge(self, github_summary: dict, calendar_events: list = None,
138138

139139
self.knowledge = info_prompt
140140
return info_prompt
141-
141+
142+
@trace_component("reporter")
142143
def gen_report(self, github_summary: dict, calendar_events: list = None, okr: str = None):
143144
"""
144145
Handle the query from the model query response.
@@ -165,4 +166,4 @@ def gen_report(self, github_summary: dict, calendar_events: list = None, okr: st
165166
result_dict = json.loads(text)
166167
with open(f'progress_report_{today}.json', 'w') as f:
167168
json.dump(result_dict, f)
168-
return result_dict
169+
return result_dict

mle/cli.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pathlib import Path
99
from rich.console import Console
1010
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn, TextColumn, BarColumn
11+
import json
1112

1213
import mle
1314
from mle.server import app
@@ -21,6 +22,7 @@
2122
)
2223
from mle.utils import CodeChunker
2324
from mle.utils import LanceDBMemory, list_files, read_file
25+
from mle.utils.component_memory import ComponentMemory
2426

2527
console = Console()
2628

@@ -303,7 +305,7 @@ def new(name):
303305
base_url = questionary.text(
304306
"What is your vLLM server URL? (default: http://localhost:8000/v1)"
305307
).ask() or "http://localhost:8000/v1"
306-
308+
307309
model_name = questionary.text(
308310
"What is the model name loaded in your vLLM server? (default: mistralai/Mistral-7B-Instruct-v0.3B)"
309311
).ask() or "mistralai/Mistral-7B-Instruct-v0.3"
@@ -433,3 +435,79 @@ def memory(add, rm, update):
433435
table_name=table_name,
434436
metadata=[{'file': file_path, 'chunk_key': k} for k, _ in chunks.items()]
435437
)
438+
439+
440+
@cli.command()
441+
@click.option('--component', type=click.Choice([
442+
'advisor', 'planner', 'coder', 'debugger', 'reporter', 'chat',
443+
'github_summarizer', 'git_summarizer'
444+
]), help='Component to view traces for')
445+
@click.option('--limit', default=5, help='Maximum number of traces to show')
446+
@click.option('--full-output', is_flag=True, help='Show complete output (not truncated)')
447+
def traces(component, limit, full_output):
448+
"""View execution traces for components."""
449+
if not component:
450+
console.print("[yellow]Please specify a component to view traces for.[/yellow]")
451+
return
452+
453+
memory = ComponentMemory(os.getcwd())
454+
traces = memory.get_recent_traces(component, limit)
455+
456+
if not traces:
457+
console.print(f"[yellow]No traces found for component: {component}[/yellow]")
458+
return
459+
460+
console.print(f"[green]Recent {component} traces:[/green]")
461+
462+
for i, trace in enumerate(traces):
463+
console.print(f"\n[bold cyan]Trace #{i+1}[/bold cyan] ({trace['timestamp']})")
464+
console.print(f"Status: {trace['status']}")
465+
466+
if trace['execution_time']:
467+
console.print(f"Execution Time: {trace['execution_time']:.2f} seconds")
468+
469+
# Show context information
470+
if trace['context']:
471+
context = trace['context']
472+
if isinstance(context, dict) and context:
473+
console.print("\n[bold]Context:[/bold]")
474+
for key, value in context.items():
475+
console.print(f" {key}: {value}")
476+
477+
# Show full input
478+
console.print("\n[bold]Input:[/bold]")
479+
if isinstance(trace['input_data'], str):
480+
if full_output:
481+
console.print(trace['input_data'])
482+
else:
483+
console.print(trace['input_data'][:500] + ("..." if len(trace['input_data']) > 500 else ""))
484+
else:
485+
# Handle dictionary or other structured data
486+
input_str = json.dumps(trace['input_data'], indent=2) if isinstance(trace['input_data'], (dict, list)) else str(trace['input_data'])
487+
if full_output:
488+
console.print(input_str)
489+
else:
490+
console.print(input_str[:500] + ("..." if len(input_str) > 500 else ""))
491+
492+
# Show full output
493+
console.print("\n[bold]Output:[/bold]")
494+
if isinstance(trace['output_data'], str):
495+
if full_output:
496+
console.print(trace['output_data'])
497+
else:
498+
console.print(trace['output_data'][:500] + ("..." if len(trace['output_data']) > 500 else ""))
499+
else:
500+
# Handle dictionary or other structured data
501+
output_str = json.dumps(trace['output_data'], indent=2) if isinstance(trace['output_data'], (dict, list)) else str(trace['output_data'])
502+
if full_output:
503+
console.print(output_str)
504+
else:
505+
console.print(output_str[:500] + ("..." if len(output_str) > 500 else ""))
506+
507+
console.print("-" * 50)
508+
509+
# Show command for seeing full output
510+
if not full_output and traces:
511+
console.print("[yellow]Tip: Use --full-output flag to see complete trace data[/yellow]")
512+
513+
memory.close()

0 commit comments

Comments
 (0)