Skip to content

Commit 8c2ec57

Browse files
authored
🧮 add hybrid memory for long-term and fast retrieval memory
Merge pull request #311 from leeeizhang/lei/hybrid_memory
2 parents 52452f1 + 67f8e96 commit 8c2ec57

File tree

1 file changed

+161
-2
lines changed

1 file changed

+161
-2
lines changed

mle/utils/memory.py

Lines changed: 161 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from datetime import datetime
12
import uuid
2-
from typing import List, Dict, Optional
3+
from typing import List, Dict, Optional, Any
34

45
import lancedb
56
from lancedb.embeddings import get_registry
@@ -340,4 +341,162 @@ def reset(self):
340341
Returns:
341342
Any: Result of the memory client's reset operation.
342343
"""
343-
return self.client.reset(agent_id=self.agent_id)
344+
return self.client.reset()
345+
346+
347+
class HybridMemory:
    """
    A hybrid memory system that integrates a slow, long-term memory (e.g., Mem0)
    with a fast, high-recall memory (e.g., LanceDB) to support dynamic memory
    consolidation and retrieval for LLM agents.

    New memories are written to the slow memory; the ``*_consolidate`` methods
    copy a selected subset of them into the fast memory.

    Attributes:
        slow_memory (Mem0): The long-term, slower-access memory backend.
        fast_memory (LanceDBMemory): The short-term, fast-access vector memory backend.
    """

    def __init__(self, slow_memory: "Mem0", fast_memory: "LanceDBMemory"):
        """
        Initialize the HybridMemory with given slow and fast memory backends.

        Args:
            slow_memory (Mem0): An instance of slow memory (long-term storage).
            fast_memory (LanceDBMemory): An instance of fast memory (vector store).
        """
        self.slow_memory: "Mem0" = slow_memory
        self.fast_memory: "LanceDBMemory" = fast_memory

    def add(
        self,
        messages: List[Dict[str, str]],
        metadata: Optional[Dict[str, Any]] = None,
        prompt: Optional[str] = None,
    ):
        """
        Add a set of messages to the slow memory store with optional prompt context.

        Args:
            messages (List[Dict[str, str]]): Conversation messages to store.
            metadata (Dict[str, Any], optional): Metadata associated with the memory.
            prompt (str, optional): An optional prompt or context. Inference is
                requested from the slow memory only when a prompt is supplied.

        Returns:
            Any: Whatever the slow memory's ``add`` returns.
        """
        return self.slow_memory.add(
            messages=messages,
            metadata=metadata,
            prompt=prompt,
            infer=prompt is not None,
        )

    def query(
        self,
        query: str,
        n_results: int = 5,
        fast_query: bool = True,
    ):
        """
        Query memory for relevant items from fast memory and optionally from slow memory.

        Args:
            query (str): The search query string.
            n_results (int): Number of top results to retrieve from each backend.
            fast_query (bool): If True, only query fast memory; otherwise, include slow memory.

        Returns:
            List[Dict]: Retrieved memory items. With ``fast_query=True`` this is the
            fast memory's result unchanged; otherwise it is a new list holding the
            fast results followed by the slow results.
        """
        fast_hits = self.fast_memory.query([query], n_results=n_results)
        if fast_query:
            return fast_hits

        # Copy before merging so the list handed back by the fast backend is
        # never mutated on the caller's behalf.
        merged = list(fast_hits or [])
        slow_hits = self.slow_memory.query(query, n_results=n_results)
        # NOTE(review): get_all() on the slow memory returns a {"results": [...]}
        # envelope; if query() does the same, unwrap it so dictionary keys are not
        # spliced into the result list. Confirm against the slow backend.
        if isinstance(slow_hits, dict):
            slow_hits = slow_hits.get("results")
        merged.extend(slow_hits or [])
        return merged

    def reset(self, only_reset_slow_memory: bool = True):
        """
        Reset memory backends to empty state.

        Args:
            only_reset_slow_memory (bool): If True, only reset slow memory; otherwise reset both.
        """
        self.slow_memory.reset()
        if not only_reset_slow_memory:
            self.fast_memory.reset()

    @staticmethod
    def _item_timestamp(item: Dict[str, Any]) -> Optional[datetime]:
        """Return the item's ``updated_at`` (else ``created_at``) as an aware datetime, or None."""
        from datetime import timezone

        raw = item.get("updated_at") or item.get("created_at")
        if not raw:
            return None
        if isinstance(raw, str) and raw.endswith(("Z", "z")):
            # datetime.fromisoformat only accepts a trailing "Z" from Python 3.11.
            raw = raw[:-1] + "+00:00"
        stamp = datetime.fromisoformat(raw)
        if stamp.tzinfo is None:
            # Naive and aware datetimes cannot be compared with each other.
            # Assumes naive timestamps are UTC -- TODO confirm with the backend.
            stamp = stamp.replace(tzinfo=timezone.utc)
        return stamp

    def _push_to_fast_memory(self, items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Add the ``memory`` text of each item to fast memory and return the items."""
        for item in items:
            self.fast_memory.add(
                texts=[item["memory"]],
            )
        return items

    def last_n_consolidate(self, n: int, limit: int = 1000):
        """
        Consolidate the most recent N entries from slow memory into fast memory.

        Recency is taken from ``updated_at``, falling back to ``created_at``.
        Items carrying neither timestamp are ranked after all dated items.

        Warning:
            Performs in-memory sort which can be memory intensive.

        Args:
            n (int): Number of most recent memory items to consolidate. Values
                less than or equal to zero consolidate nothing.
            limit (int): Maximum number of items to retrieve from slow memory.

        Returns:
            List[Dict]: The last N memory items that were consolidated.

        Raises:
            ValueError: If a timestamp is not a valid ISO-8601 string.
        """
        if n <= 0:
            return []

        # This method performs a full in-memory sort of all memory entries, which
        # may result in significant memory and CPU usage if the memory store is
        # large. Use with caution when the number of stored memory items is large.
        items = self.slow_memory.get_all(n_results=limit)["results"]

        # TODO: ranking memory items with timestamp iteratively
        dated = []
        undated = []
        for item in items:
            stamp = self._item_timestamp(item)
            if stamp is None:
                undated.append(item)
            else:
                dated.append((stamp, item))
        # Sort on the timestamp only, so ties never fall through to comparing dicts.
        dated.sort(key=lambda pair: pair[0], reverse=True)
        ranked = [item for _, item in dated] + undated

        return self._push_to_fast_memory(ranked[:n])

    def top_k_consolidate(
        self, k: int, metadata_key: str, reverse=False, limit: int = 1000
    ):
        """
        Consolidate top-K entries from slow memory based on a metadata key.

        Items whose metadata is missing, or does not carry a non-None value for
        ``metadata_key``, cannot be ranked and are skipped.

        Warning:
            Performs full in-memory sort and should be used cautiously on large datasets.

        Args:
            k (int): Number of top memory items to consolidate. Values less than
                or equal to zero consolidate nothing.
            metadata_key (str): Metadata key used for sorting and selection.
            reverse (bool): Whether to sort in descending order.
            limit (int): Maximum number of items to retrieve from slow memory.

        Returns:
            List[Dict]: The top-K memory items that were consolidated.
        """
        if k <= 0:
            return []

        # This method performs a full in-memory sort of all memory entries, which
        # may result in significant memory and CPU usage if the memory store is
        # large. Use with caution when the number of stored memory items is large.
        items = self.slow_memory.get_all(n_results=limit)["results"]

        # TODO: ranking items with manual function iteratively
        rankable = []
        for item in items:
            value = (item.get("metadata") or {}).get(metadata_key)
            if value is not None:
                rankable.append((value, item))
        # Sort on the metadata value only, so ties never compare the dicts.
        rankable.sort(key=lambda pair: pair[0], reverse=reverse)

        return self._push_to_fast_memory([item for _, item in rankable[:k]])

    def prompt_based_consolidate(self, prompt: str):
        """
        Consolidate memory items into fast memory based on prompt relevance.

        Note: [not yet implemented]

        Args:
            prompt (str): The guiding prompt used to select memory items.

        Raises:
            NotImplementedError: Always; this strategy is not implemented yet.
        """
        raise NotImplementedError("prompt_based_consolidate is not implemented yet")

0 commit comments

Comments
 (0)