|
| 1 | +from datetime import datetime |
1 | 2 | import uuid |
2 | | -from typing import List, Dict, Optional |
| 3 | +from typing import List, Dict, Optional, Any |
3 | 4 |
|
4 | 5 | import lancedb |
5 | 6 | from lancedb.embeddings import get_registry |
@@ -340,4 +341,162 @@ def reset(self): |
340 | 341 | Returns: |
341 | 342 | Any: Result of the memory client's reset operation. |
342 | 343 | """ |
343 | | - return self.client.reset(agent_id=self.agent_id) |
| 344 | + return self.client.reset() |
| 345 | + |
| 346 | + |
| 347 | +class HybridMemory: |
| 348 | + """ |
| 349 | + A hybrid memory system that integrates a slow, long-term memory (e.g., Mem0) |
| 350 | + with a fast, high-recall memory (e.g., LanceDB) to support dynamic memory |
| 351 | + consolidation and retrieval for LLM agents. |
| 352 | +
|
| 353 | + Attributes: |
| 354 | + slow_memory (Mem0): The long-term, slower-access memory backend. |
| 355 | + fast_memory (LanceDBMemory): The short-term, fast-access vector memory backend. |
| 356 | + """ |
| 357 | + |
| 358 | + def __init__(self, slow_memory: Mem0, fast_memory: LanceDBMemory): |
| 359 | + """ |
| 360 | + Initialize the HybridMemory with given slow and fast memory backends. |
| 361 | +
|
| 362 | + Args: |
| 363 | + slow_memory (Mem0): An instance of slow memory (long-term storage). |
| 364 | + fast_memory (LanceDBMemory): An instance of fast memory (vector store). |
| 365 | + """ |
| 366 | + self.slow_memory: Mem0 = slow_memory |
| 367 | + self.fast_memory: LanceDBMemory = fast_memory |
| 368 | + |
| 369 | + def add( |
| 370 | + self, |
| 371 | + messages: List[Dict[str, str]], |
| 372 | + metadata: Dict[str, Any] = None, |
| 373 | + prompt: str = None, |
| 374 | + ): |
| 375 | + """ |
| 376 | + Add a set of messages to the slow memory store with optional prompt context. |
| 377 | +
|
| 378 | + Args: |
| 379 | + messages (List[Dict[str, str]]): Conversation messages to store. |
| 380 | + metadata (Dict[str, Any]): Metadata associated with the memory. |
| 381 | + prompt (str, optional): An optional prompt or context. |
| 382 | + """ |
| 383 | + return self.slow_memory.add( |
| 384 | + messages=messages, |
| 385 | + metadata=metadata, |
| 386 | + prompt=prompt, |
| 387 | + infer=prompt is not None, |
| 388 | + ) |
| 389 | + |
| 390 | + def query( |
| 391 | + self, |
| 392 | + query: str, |
| 393 | + n_results: int = 5, |
| 394 | + fast_query: bool = True, |
| 395 | + ): |
| 396 | + """ |
| 397 | + Query memory for relevant items from fast memory and optionally from slow memory. |
| 398 | +
|
| 399 | + Args: |
| 400 | + query (str): The search query string. |
| 401 | + n_results (int): Number of top results to retrieve. |
| 402 | + fast_query (bool): If True, only query fast memory; otherwise, include slow memory. |
| 403 | +
|
| 404 | + Returns: |
| 405 | + List[Dict]: Retrieved memory items. |
| 406 | + """ |
| 407 | + results = self.fast_memory.query([query], n_results=n_results) |
| 408 | + if not fast_query: |
| 409 | + results.extend(self.slow_memory.query(query, n_results=n_results)) |
| 410 | + return results |
| 411 | + |
| 412 | + def reset(self, only_reset_slow_memory: bool = True): |
| 413 | + """ |
| 414 | + Reset memory backends to empty state. |
| 415 | +
|
| 416 | + Args: |
| 417 | + only_reset_slow_memory (bool): If True, only reset slow memory; otherwise reset both. |
| 418 | + """ |
| 419 | + self.slow_memory.reset() |
| 420 | + if not only_reset_slow_memory: |
| 421 | + self.fast_memory.reset() |
| 422 | + |
| 423 | + def last_n_consolidate(self, n: int, limit: int = 1000): |
| 424 | + """ |
| 425 | + Consolidate the most recent N entries from slow memory into fast memory. |
| 426 | +
|
| 427 | + Warning: |
| 428 | + Performs in-memory sort which can be memory intensive. |
| 429 | +
|
| 430 | + Args: |
| 431 | + n (int): Number of most recent memory items to consolidate. |
| 432 | + limit (int): Maximum number of items to retrieve from slow memory. |
| 433 | +
|
| 434 | + Returns: |
| 435 | + List[Dict]: The last N memory items that were consolidated. |
| 436 | + """ |
| 437 | + # This method performs a full in-memory sort of all memory entries, which |
| 438 | + # may result in significant memory and CPU usage if the memory store is |
| 439 | + # large. Use with caution when the number of stored memory items is large. |
| 440 | + items = self.slow_memory.get_all(n_results=limit)["results"] |
| 441 | + |
| 442 | + # TODO: ranking memory items with timestamp iteratively |
| 443 | + items = sorted( |
| 444 | + items, |
| 445 | + key=lambda x: datetime.fromisoformat( |
| 446 | + x.get("updated_at") or x.get("created_at") |
| 447 | + ), |
| 448 | + reverse=True, |
| 449 | + ) |
| 450 | + |
| 451 | + last_n_items = items[:n] |
| 452 | + for item in last_n_items: |
| 453 | + self.fast_memory.add( |
| 454 | + texts=[item["memory"]], |
| 455 | + ) |
| 456 | + return last_n_items |
| 457 | + |
| 458 | + def top_k_consolidate( |
| 459 | + self, k: int, metadata_key: str, reverse=False, limit: int = 1000 |
| 460 | + ): |
| 461 | + """ |
| 462 | + Consolidate top-K entries from slow memory based on a metadata key. |
| 463 | +
|
| 464 | + Warning: |
| 465 | + Performs full in-memory sort and should be used cautiously on large datasets. |
| 466 | +
|
| 467 | + Args: |
| 468 | + k (int): Number of top memory items to consolidate. |
| 469 | + metadata_key (str): Metadata key used for sorting and selection. |
| 470 | + reverse (bool): Whether to sort in descending order. |
| 471 | + limit (int): Maximum number of items to retrieve from slow memory. |
| 472 | +
|
| 473 | + Returns: |
| 474 | + List[Dict]: The top-K memory items that were consolidated. |
| 475 | + """ |
| 476 | + # This method performs a full in-memory sort of all memory entries, which |
| 477 | + # may result in significant memory and CPU usage if the memory store is |
| 478 | + # large. Use with caution when the number of stored memory items is large. |
| 479 | + items = self.slow_memory.get_all(n_results=limit)["results"] |
| 480 | + |
| 481 | + # TODO: ranking items with manual function iteratively |
| 482 | + items = sorted( |
| 483 | + items, key=lambda x: x["metadata"].get(metadata_key), reverse=reverse |
| 484 | + ) |
| 485 | + |
| 486 | + topk_items = items[:k] |
| 487 | + for item in topk_items: |
| 488 | + self.fast_memory.add( |
| 489 | + texts=[item["memory"]], |
| 490 | + ) |
| 491 | + return topk_items |
| 492 | + |
| 493 | + def prompt_based_consolidate(self, prompt: str): |
| 494 | + """ |
| 495 | + Consolidate memory items into fast memory based on prompt relevance. |
| 496 | +
|
| 497 | + Note: [not yet implemented] |
| 498 | +
|
| 499 | + Args: |
| 500 | + prompt (str): The guiding prompt used to select memory items. |
| 501 | + """ |
| 502 | + raise NotImplementedError |
0 commit comments