99import json
1010import uuid
1111import time
12- import sqlite3
1312import traceback
1413import functools
1514from datetime import datetime
1615from typing import Dict , List , Any , Optional , Union , Tuple
1716
17+ from .memory import LanceDBMemory
18+
1819
1920class ComponentMemory :
2021 """
2122 Tracks and stores execution traces for different components in MLE-Agent.
2223
23- Uses SQLite as the backend for efficient storage and querying of component traces.
24- Each component type has its own table, and traces can be queried by component,
25- timestamp, or content.
24+ Uses LanceDB as the backend for efficient storage and querying of component traces.
25+ Component traces are organized by component type and can be queried by various attributes.
2626 """
2727
2828 def __init__ (self , project_dir : str ):
@@ -33,56 +33,16 @@ def __init__(self, project_dir: str):
3333 project_dir: The project directory path.
3434 """
3535 self .project_dir = project_dir
36-
37- # Ensure the .mle directory exists
38- self .memory_dir = os .path .join (project_dir , '.mle' )
39- os .makedirs (self .memory_dir , exist_ok = True )
40-
41- # Initialize SQLite database for storing traces
42- self .db_path = os .path .join (self .memory_dir , 'component_traces.db' )
43- self .conn = sqlite3 .connect (self .db_path )
44- self .conn .row_factory = sqlite3 .Row # Access rows by name
45-
46- # Initialize tables
47- self ._initialize_tables ()
48-
49- def _initialize_tables (self ):
50- """Initialize the database tables for storing component traces."""
51- cursor = self .conn .cursor ()
52-
53- # Create a table for each component type
54- components = [
36+
37+ # Initialize LanceDB memory as the backend storage
38+ self .memory = LanceDBMemory (project_dir )
39+
40+ # Track components for easier access
41+ self .components = [
5542 'advisor' , 'planner' , 'coder' , 'debugger' , 'reporter' , 'chat' ,
5643 'github_summarizer' , 'git_summarizer'
5744 ]
5845
59- for component in components :
60- cursor .execute (f'''
61- CREATE TABLE IF NOT EXISTS { component } _traces (
62- id TEXT PRIMARY KEY,
63- timestamp TEXT,
64- project_name TEXT,
65- input_data TEXT,
66- output_data TEXT,
67- execution_time REAL,
68- context TEXT,
69- status TEXT
70- )
71- ''' )
72-
73- # Create a table for tracking relationships between traces
74- cursor .execute ('''
75- CREATE TABLE IF NOT EXISTS trace_relationships (
76- source_id TEXT,
77- target_id TEXT,
78- relationship_type TEXT,
79- metadata TEXT,
80- PRIMARY KEY (source_id, target_id, relationship_type)
81- )
82- ''' )
83-
84- self .conn .commit ()
85-
8646 def store_trace (self ,
8747 component : str ,
8848 input_data : Any ,
@@ -107,32 +67,38 @@ def store_trace(self,
10767 trace_id = str (uuid .uuid4 ())
10868 timestamp = datetime .now ().isoformat ()
10969 project_name = os .path .basename (self .project_dir )
110-
111- # Serialize complex data types
112- input_json = self ._serialize_data (input_data )
113- output_json = self ._serialize_data (output_data )
114- context_json = self ._serialize_data (context or {})
115-
116- # Store in the appropriate table
117- cursor = self .conn .cursor ()
118- query = f'''
119- INSERT INTO { component } _traces
120- (id, timestamp, project_name, input_data, output_data, execution_time, context, status)
121- VALUES (?, ?, ?, ?, ?, ?, ?, ?)
122- '''
123-
124- cursor .execute (query , (
125- trace_id ,
126- timestamp ,
127- project_name ,
128- input_json ,
129- output_json ,
130- execution_time ,
131- context_json ,
132- status
133- ))
134-
135- self .conn .commit ()
70+
71+ # Prepare text representation for vector embedding
72+ # This combines the most important fields for semantic search
73+ if isinstance (input_data , str ):
74+ input_text = input_data [:1000 ] # Limit length for embedding
75+ else :
76+ input_text = str (input_data )[:1000 ]
77+
78+ text_for_embedding = f"Component: { component } \n Status: { status } \n Input: { input_text } "
79+
80+ # Prepare metadata containing all trace details
81+ metadata = {
82+ "trace_id" : trace_id ,
83+ "component" : component ,
84+ "timestamp" : timestamp ,
85+ "project_name" : project_name ,
86+ "execution_time" : execution_time ,
87+ "status" : status ,
88+ "input_data" : self ._serialize_data (input_data ),
89+ "output_data" : self ._serialize_data (output_data ),
90+ "context" : self ._serialize_data (context or {})
91+ }
92+
93+ # Store in the component-specific table
94+ table_name = f"component_{ component } _traces"
95+ self .memory .add (
96+ texts = [text_for_embedding ],
97+ metadata = [metadata ],
98+ table_name = table_name ,
99+ ids = [trace_id ]
100+ )
101+
136102 return trace_id
137103
138104 def get_trace (self , component : str , trace_id : str ) -> Optional [Dict [str , Any ]]:
@@ -146,14 +112,13 @@ def get_trace(self, component: str, trace_id: str) -> Optional[Dict[str, Any]]:
146112 Returns:
147113 Dict or None: The trace data if found, None otherwise.
148114 """
149- cursor = self .conn .cursor ()
150- query = f"SELECT * FROM { component } _traces WHERE id = ?"
151- cursor .execute (query , (trace_id ,))
152-
153- row = cursor .fetchone ()
154- if row :
155- return self ._row_to_dict (row )
156- return None
115+ table_name = f"component_{ component } _traces"
116+ results = self .memory .get (trace_id , table_name = table_name )
117+
118+ if not results :
119+ return None
120+
121+ return self ._process_trace_result (results [0 ])
157122
158123 def get_recent_traces (self , component : str , limit : int = 10 ) -> List [Dict [str , Any ]]:
159124 """
@@ -166,11 +131,29 @@ def get_recent_traces(self, component: str, limit: int = 10) -> List[Dict[str, A
166131 Returns:
167132 List[Dict]: List of trace data dictionaries.
168133 """
169- cursor = self .conn .cursor ()
170- query = f"SELECT * FROM { component } _traces ORDER BY timestamp DESC LIMIT ?"
171- cursor .execute (query , (limit ,))
172-
173- return [self ._row_to_dict (row ) for row in cursor .fetchall ()]
134+ table_name = f"component_{ component } _traces"
135+
136+ # 1. Get all IDs
137+ # 2. Get metadata for each ID
138+ # 3. Sort by timestamp
139+ # 4. Take the most recent ones
140+
141+ all_keys = self .memory .list_all_keys (table_name = table_name )
142+ if not all_keys :
143+ return []
144+
145+ # Get all traces for this component
146+ traces = []
147+ for key in all_keys :
148+ result = self .memory .get (key , table_name = table_name )
149+ if result :
150+ traces .append (self ._process_trace_result (result [0 ]))
151+
152+ # Sort by timestamp (newest first)
153+ traces .sort (key = lambda x : x ['timestamp' ], reverse = True )
154+
155+ # Return only the requested number
156+ return traces [:limit ]
174157
175158 def search_traces (self ,
176159 component : str ,
@@ -187,16 +170,19 @@ def search_traces(self,
187170 Returns:
188171 List[Dict]: List of matching trace dictionaries.
189172 """
190- cursor = self .conn .cursor ()
191- query = f'''
192- SELECT * FROM { component } _traces
193- WHERE input_data LIKE ? OR output_data LIKE ?
194- ORDER BY timestamp DESC LIMIT ?
195- '''
196- search_pattern = f"%{ search_text } %"
197- cursor .execute (query , (search_pattern , search_pattern , limit ))
198-
199- return [self ._row_to_dict (row ) for row in cursor .fetchall ()]
173+ table_name = f"component_{ component } _traces"
174+
175+ # Use LanceDB's vector search capability
176+ results = self .memory .query (
177+ query_texts = [search_text ],
178+ table_name = table_name ,
179+ n_results = limit
180+ )
181+
182+ if not results or not results [0 ]:
183+ return []
184+
185+ return [self ._process_trace_result (item ) for item in results [0 ]]
200186
201187 def add_relationship (self ,
202188 source_id : str ,
@@ -215,21 +201,27 @@ def add_relationship(self,
215201 Returns:
216202 bool: True if relationship was added successfully.
217203 """
218- cursor = self .conn .cursor ()
219- metadata_json = self ._serialize_data (metadata or {})
220-
221- try :
222- cursor .execute ('''
223- INSERT INTO trace_relationships
224- (source_id, target_id, relationship_type, metadata)
225- VALUES (?, ?, ?, ?)
226- ''' , (source_id , target_id , relationship_type , metadata_json ))
227-
228- self .conn .commit ()
229- return True
230- except sqlite3 .IntegrityError :
231- # Relationship already exists
232- return False
204+ relationship_id = f"{ source_id } _{ target_id } _{ relationship_type } "
205+
206+ relationship_text = f"Relationship: { relationship_type } from { source_id } to { target_id } "
207+
208+ relationship_metadata = {
209+ "source_id" : source_id ,
210+ "target_id" : target_id ,
211+ "relationship_type" : relationship_type ,
212+ "metadata" : self ._serialize_data (metadata or {})
213+ }
214+
215+ # Store in the relationships table
216+ table_name = "component_trace_relationships"
217+ self .memory .add (
218+ texts = [relationship_text ],
219+ metadata = [relationship_metadata ],
220+ table_name = table_name ,
221+ ids = [relationship_id ]
222+ )
223+
224+ return True
233225
234226 def get_related_traces (self ,
235227 trace_id : str ,
@@ -244,37 +236,47 @@ def get_related_traces(self,
244236 Returns:
245237 List[Dict]: List of related trace data.
246238 """
247- cursor = self .conn .cursor ()
248-
239+ table_name = "component_trace_relationships"
240+
241+ # Get all relationships where this trace is the source
249242 if relationship_type :
250- query = '''
251- SELECT * FROM trace_relationships
252- WHERE source_id = ? AND relationship_type = ?
253- '''
254- cursor .execute (query , (trace_id , relationship_type ))
243+ # Get relationships with specific type
244+ results = self .memory .query (
245+ query_texts = [f"Relationship: { relationship_type } from { trace_id } " ],
246+ table_name = table_name ,
247+ n_results = 100 # Get many potential matches
248+ )
255249 else :
256- query = '''
257- SELECT * FROM trace_relationships
258- WHERE source_id = ?
259- '''
260- cursor .execute (query , (trace_id ,))
261-
250+ # Get all relationships for this trace
251+ all_keys = self .memory .list_all_keys (table_name = table_name )
252+ results = []
253+
254+ for key in all_keys :
255+ if key .startswith (f"{ trace_id } _" ):
256+ rel = self .memory .get (key , table_name = table_name )
257+ if rel :
258+ results .append (rel [0 ])
259+
260+ if not results or (isinstance (results , list ) and not results [0 ]):
261+ return []
262+
263+ # Process and return relationship data
262264 relationships = []
263- for row in cursor .fetchall ():
264- rel = {
265- 'source_id' : row ['source_id' ],
266- 'target_id' : row ['target_id' ],
267- 'relationship_type' : row ['relationship_type' ],
268- 'metadata' : json .loads (row ['metadata' ])
269- }
270- relationships .append (rel )
271-
265+ for item in results if not isinstance (results [0 ], list ) else results [0 ]:
266+ if 'metadata' in item and isinstance (item ['metadata' ], dict ):
267+ rel_data = {
268+ 'source_id' : item ['metadata' ].get ('source_id' ),
269+ 'target_id' : item ['metadata' ].get ('target_id' ),
270+ 'relationship_type' : item ['metadata' ].get ('relationship_type' ),
271+ 'metadata' : json .loads (item ['metadata' ].get ('metadata' , '{}' ))
272+ }
273+ relationships .append (rel_data )
274+
272275 return relationships
273276
274277 def close (self ):
275- """Close the database connection."""
276- if self .conn :
277- self .conn .close ()
278+ """Close the memory connections."""
279+ pass
278280
279281 def _serialize_data (self , data : Any ) -> str :
280282 """Serialize data to JSON string."""
@@ -290,16 +292,27 @@ def _deserialize_data(self, json_str: str) -> Any:
290292 return json .loads (json_str )
291293 except (json .JSONDecodeError , TypeError ):
292294 return json_str
293-
294- def _row_to_dict (self , row : sqlite3 .Row ) -> Dict [str , Any ]:
295- """Convert an SQLite row to a dictionary with deserialized data."""
296- trace = dict (row )
297-
298- # Deserialize JSON fields
299- trace ['input_data' ] = self ._deserialize_data (trace ['input_data' ])
300- trace ['output_data' ] = self ._deserialize_data (trace ['output_data' ])
301- trace ['context' ] = self ._deserialize_data (trace ['context' ])
302-
295+
296+ def _process_trace_result (self , result : Dict ) -> Dict [str , Any ]:
297+ """Process a raw trace result from LanceDB into a standardized format."""
298+ if not result or 'metadata' not in result :
299+ return {}
300+
301+ metadata = result ['metadata' ]
302+
303+ # Extract and deserialize the trace data
304+ trace = {
305+ 'id' : metadata .get ('trace_id' ),
306+ 'component' : metadata .get ('component' ),
307+ 'timestamp' : metadata .get ('timestamp' ),
308+ 'project_name' : metadata .get ('project_name' ),
309+ 'execution_time' : metadata .get ('execution_time' ),
310+ 'status' : metadata .get ('status' ),
311+ 'input_data' : self ._deserialize_data (metadata .get ('input_data' , '{}' )),
312+ 'output_data' : self ._deserialize_data (metadata .get ('output_data' , '{}' )),
313+ 'context' : self ._deserialize_data (metadata .get ('context' , '{}' ))
314+ }
315+
303316 return trace
304317
305318
0 commit comments