Skip to content

Commit 574c689

Browse files
committed
Replace SQLite with LanceDB for the component memory implementation and use simple text embedding for component memory
1 parent 17325cd commit 574c689

File tree

2 files changed

+165
-150
lines changed

2 files changed

+165
-150
lines changed

mle/utils/component_memory.py

Lines changed: 162 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,20 @@
99
import json
1010
import uuid
1111
import time
12-
import sqlite3
1312
import traceback
1413
import functools
1514
from datetime import datetime
1615
from typing import Dict, List, Any, Optional, Union, Tuple
1716

17+
from .memory import LanceDBMemory
18+
1819

1920
class ComponentMemory:
2021
"""
2122
Tracks and stores execution traces for different components in MLE-Agent.
2223
23-
Uses SQLite as the backend for efficient storage and querying of component traces.
24-
Each component type has its own table, and traces can be queried by component,
25-
timestamp, or content.
24+
Uses LanceDB as the backend for efficient storage and querying of component traces.
25+
Component traces are organized by component type and can be queried by various attributes.
2626
"""
2727

2828
def __init__(self, project_dir: str):
@@ -33,56 +33,16 @@ def __init__(self, project_dir: str):
3333
project_dir: The project directory path.
3434
"""
3535
self.project_dir = project_dir
36-
37-
# Ensure the .mle directory exists
38-
self.memory_dir = os.path.join(project_dir, '.mle')
39-
os.makedirs(self.memory_dir, exist_ok=True)
40-
41-
# Initialize SQLite database for storing traces
42-
self.db_path = os.path.join(self.memory_dir, 'component_traces.db')
43-
self.conn = sqlite3.connect(self.db_path)
44-
self.conn.row_factory = sqlite3.Row # Access rows by name
45-
46-
# Initialize tables
47-
self._initialize_tables()
48-
49-
def _initialize_tables(self):
50-
"""Initialize the database tables for storing component traces."""
51-
cursor = self.conn.cursor()
52-
53-
# Create a table for each component type
54-
components = [
36+
37+
# Initialize LanceDB memory as the backend storage
38+
self.memory = LanceDBMemory(project_dir)
39+
40+
# Track components for easier access
41+
self.components = [
5542
'advisor', 'planner', 'coder', 'debugger', 'reporter', 'chat',
5643
'github_summarizer', 'git_summarizer'
5744
]
5845

59-
for component in components:
60-
cursor.execute(f'''
61-
CREATE TABLE IF NOT EXISTS {component}_traces (
62-
id TEXT PRIMARY KEY,
63-
timestamp TEXT,
64-
project_name TEXT,
65-
input_data TEXT,
66-
output_data TEXT,
67-
execution_time REAL,
68-
context TEXT,
69-
status TEXT
70-
)
71-
''')
72-
73-
# Create a table for tracking relationships between traces
74-
cursor.execute('''
75-
CREATE TABLE IF NOT EXISTS trace_relationships (
76-
source_id TEXT,
77-
target_id TEXT,
78-
relationship_type TEXT,
79-
metadata TEXT,
80-
PRIMARY KEY (source_id, target_id, relationship_type)
81-
)
82-
''')
83-
84-
self.conn.commit()
85-
8646
def store_trace(self,
8747
component: str,
8848
input_data: Any,
@@ -107,32 +67,38 @@ def store_trace(self,
10767
trace_id = str(uuid.uuid4())
10868
timestamp = datetime.now().isoformat()
10969
project_name = os.path.basename(self.project_dir)
110-
111-
# Serialize complex data types
112-
input_json = self._serialize_data(input_data)
113-
output_json = self._serialize_data(output_data)
114-
context_json = self._serialize_data(context or {})
115-
116-
# Store in the appropriate table
117-
cursor = self.conn.cursor()
118-
query = f'''
119-
INSERT INTO {component}_traces
120-
(id, timestamp, project_name, input_data, output_data, execution_time, context, status)
121-
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
122-
'''
123-
124-
cursor.execute(query, (
125-
trace_id,
126-
timestamp,
127-
project_name,
128-
input_json,
129-
output_json,
130-
execution_time,
131-
context_json,
132-
status
133-
))
134-
135-
self.conn.commit()
70+
71+
# Prepare text representation for vector embedding
72+
# This combines the most important fields for semantic search
73+
if isinstance(input_data, str):
74+
input_text = input_data[:1000] # Limit length for embedding
75+
else:
76+
input_text = str(input_data)[:1000]
77+
78+
text_for_embedding = f"Component: {component}\nStatus: {status}\nInput: {input_text}"
79+
80+
# Prepare metadata containing all trace details
81+
metadata = {
82+
"trace_id": trace_id,
83+
"component": component,
84+
"timestamp": timestamp,
85+
"project_name": project_name,
86+
"execution_time": execution_time,
87+
"status": status,
88+
"input_data": self._serialize_data(input_data),
89+
"output_data": self._serialize_data(output_data),
90+
"context": self._serialize_data(context or {})
91+
}
92+
93+
# Store in the component-specific table
94+
table_name = f"component_{component}_traces"
95+
self.memory.add(
96+
texts=[text_for_embedding],
97+
metadata=[metadata],
98+
table_name=table_name,
99+
ids=[trace_id]
100+
)
101+
136102
return trace_id
137103

138104
def get_trace(self, component: str, trace_id: str) -> Optional[Dict[str, Any]]:
@@ -146,14 +112,13 @@ def get_trace(self, component: str, trace_id: str) -> Optional[Dict[str, Any]]:
146112
Returns:
147113
Dict or None: The trace data if found, None otherwise.
148114
"""
149-
cursor = self.conn.cursor()
150-
query = f"SELECT * FROM {component}_traces WHERE id = ?"
151-
cursor.execute(query, (trace_id,))
152-
153-
row = cursor.fetchone()
154-
if row:
155-
return self._row_to_dict(row)
156-
return None
115+
table_name = f"component_{component}_traces"
116+
results = self.memory.get(trace_id, table_name=table_name)
117+
118+
if not results:
119+
return None
120+
121+
return self._process_trace_result(results[0])
157122

158123
def get_recent_traces(self, component: str, limit: int = 10) -> List[Dict[str, Any]]:
159124
"""
@@ -166,11 +131,29 @@ def get_recent_traces(self, component: str, limit: int = 10) -> List[Dict[str, A
166131
Returns:
167132
List[Dict]: List of trace data dictionaries.
168133
"""
169-
cursor = self.conn.cursor()
170-
query = f"SELECT * FROM {component}_traces ORDER BY timestamp DESC LIMIT ?"
171-
cursor.execute(query, (limit,))
172-
173-
return [self._row_to_dict(row) for row in cursor.fetchall()]
134+
table_name = f"component_{component}_traces"
135+
136+
# 1. Get all IDs
137+
# 2. Get metadata for each ID
138+
# 3. Sort by timestamp
139+
# 4. Take the most recent ones
140+
141+
all_keys = self.memory.list_all_keys(table_name=table_name)
142+
if not all_keys:
143+
return []
144+
145+
# Get all traces for this component
146+
traces = []
147+
for key in all_keys:
148+
result = self.memory.get(key, table_name=table_name)
149+
if result:
150+
traces.append(self._process_trace_result(result[0]))
151+
152+
# Sort by timestamp (newest first)
153+
traces.sort(key=lambda x: x['timestamp'], reverse=True)
154+
155+
# Return only the requested number
156+
return traces[:limit]
174157

175158
def search_traces(self,
176159
component: str,
@@ -187,16 +170,19 @@ def search_traces(self,
187170
Returns:
188171
List[Dict]: List of matching trace dictionaries.
189172
"""
190-
cursor = self.conn.cursor()
191-
query = f'''
192-
SELECT * FROM {component}_traces
193-
WHERE input_data LIKE ? OR output_data LIKE ?
194-
ORDER BY timestamp DESC LIMIT ?
195-
'''
196-
search_pattern = f"%{search_text}%"
197-
cursor.execute(query, (search_pattern, search_pattern, limit))
198-
199-
return [self._row_to_dict(row) for row in cursor.fetchall()]
173+
table_name = f"component_{component}_traces"
174+
175+
# Use LanceDB's vector search capability
176+
results = self.memory.query(
177+
query_texts=[search_text],
178+
table_name=table_name,
179+
n_results=limit
180+
)
181+
182+
if not results or not results[0]:
183+
return []
184+
185+
return [self._process_trace_result(item) for item in results[0]]
200186

201187
def add_relationship(self,
202188
source_id: str,
@@ -215,21 +201,27 @@ def add_relationship(self,
215201
Returns:
216202
bool: True if relationship was added successfully.
217203
"""
218-
cursor = self.conn.cursor()
219-
metadata_json = self._serialize_data(metadata or {})
220-
221-
try:
222-
cursor.execute('''
223-
INSERT INTO trace_relationships
224-
(source_id, target_id, relationship_type, metadata)
225-
VALUES (?, ?, ?, ?)
226-
''', (source_id, target_id, relationship_type, metadata_json))
227-
228-
self.conn.commit()
229-
return True
230-
except sqlite3.IntegrityError:
231-
# Relationship already exists
232-
return False
204+
relationship_id = f"{source_id}_{target_id}_{relationship_type}"
205+
206+
relationship_text = f"Relationship: {relationship_type} from {source_id} to {target_id}"
207+
208+
relationship_metadata = {
209+
"source_id": source_id,
210+
"target_id": target_id,
211+
"relationship_type": relationship_type,
212+
"metadata": self._serialize_data(metadata or {})
213+
}
214+
215+
# Store in the relationships table
216+
table_name = "component_trace_relationships"
217+
self.memory.add(
218+
texts=[relationship_text],
219+
metadata=[relationship_metadata],
220+
table_name=table_name,
221+
ids=[relationship_id]
222+
)
223+
224+
return True
233225

234226
def get_related_traces(self,
235227
trace_id: str,
@@ -244,37 +236,47 @@ def get_related_traces(self,
244236
Returns:
245237
List[Dict]: List of related trace data.
246238
"""
247-
cursor = self.conn.cursor()
248-
239+
table_name = "component_trace_relationships"
240+
241+
# Get all relationships where this trace is the source
249242
if relationship_type:
250-
query = '''
251-
SELECT * FROM trace_relationships
252-
WHERE source_id = ? AND relationship_type = ?
253-
'''
254-
cursor.execute(query, (trace_id, relationship_type))
243+
# Get relationships with specific type
244+
results = self.memory.query(
245+
query_texts=[f"Relationship: {relationship_type} from {trace_id}"],
246+
table_name=table_name,
247+
n_results=100 # Get many potential matches
248+
)
255249
else:
256-
query = '''
257-
SELECT * FROM trace_relationships
258-
WHERE source_id = ?
259-
'''
260-
cursor.execute(query, (trace_id,))
261-
250+
# Get all relationships for this trace
251+
all_keys = self.memory.list_all_keys(table_name=table_name)
252+
results = []
253+
254+
for key in all_keys:
255+
if key.startswith(f"{trace_id}_"):
256+
rel = self.memory.get(key, table_name=table_name)
257+
if rel:
258+
results.append(rel[0])
259+
260+
if not results or (isinstance(results, list) and not results[0]):
261+
return []
262+
263+
# Process and return relationship data
262264
relationships = []
263-
for row in cursor.fetchall():
264-
rel = {
265-
'source_id': row['source_id'],
266-
'target_id': row['target_id'],
267-
'relationship_type': row['relationship_type'],
268-
'metadata': json.loads(row['metadata'])
269-
}
270-
relationships.append(rel)
271-
265+
for item in results if not isinstance(results[0], list) else results[0]:
266+
if 'metadata' in item and isinstance(item['metadata'], dict):
267+
rel_data = {
268+
'source_id': item['metadata'].get('source_id'),
269+
'target_id': item['metadata'].get('target_id'),
270+
'relationship_type': item['metadata'].get('relationship_type'),
271+
'metadata': json.loads(item['metadata'].get('metadata', '{}'))
272+
}
273+
relationships.append(rel_data)
274+
272275
return relationships
273276

274277
def close(self):
275-
"""Close the database connection."""
276-
if self.conn:
277-
self.conn.close()
278+
"""Close the memory connections."""
279+
pass
278280

279281
def _serialize_data(self, data: Any) -> str:
280282
"""Serialize data to JSON string."""
@@ -290,16 +292,27 @@ def _deserialize_data(self, json_str: str) -> Any:
290292
return json.loads(json_str)
291293
except (json.JSONDecodeError, TypeError):
292294
return json_str
293-
294-
def _row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]:
295-
"""Convert an SQLite row to a dictionary with deserialized data."""
296-
trace = dict(row)
297-
298-
# Deserialize JSON fields
299-
trace['input_data'] = self._deserialize_data(trace['input_data'])
300-
trace['output_data'] = self._deserialize_data(trace['output_data'])
301-
trace['context'] = self._deserialize_data(trace['context'])
302-
295+
296+
def _process_trace_result(self, result: Dict) -> Dict[str, Any]:
297+
"""Process a raw trace result from LanceDB into a standardized format."""
298+
if not result or 'metadata' not in result:
299+
return {}
300+
301+
metadata = result['metadata']
302+
303+
# Extract and deserialize the trace data
304+
trace = {
305+
'id': metadata.get('trace_id'),
306+
'component': metadata.get('component'),
307+
'timestamp': metadata.get('timestamp'),
308+
'project_name': metadata.get('project_name'),
309+
'execution_time': metadata.get('execution_time'),
310+
'status': metadata.get('status'),
311+
'input_data': self._deserialize_data(metadata.get('input_data', '{}')),
312+
'output_data': self._deserialize_data(metadata.get('output_data', '{}')),
313+
'context': self._deserialize_data(metadata.get('context', '{}'))
314+
}
315+
303316
return trace
304317

305318

0 commit comments

Comments
 (0)