Skip to content

Commit 403a5aa

Browse files
committed
Make memory extraction configurable
1 parent ce537c4 commit 403a5aa

File tree

4 files changed

+33
-20
lines changed

4 files changed

+33
-20
lines changed

agent_memory_server/config.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,23 @@ class Settings(BaseSettings):
2828
port: int = 8000
2929
mcp_port: int = 9000
3030

31-
# Topic and NER model settings
32-
topic_model_source: Literal["NER", "LLM"] = "LLM"
33-
topic_model: str = "MaartenGr/BERTopic_Wikipedia" # LLM model here if using LLM
34-
ner_model: str = "dbmdz/bert-large-cased-finetuned-conll03-english"
31+
# The server indexes messages in long-term memory by default. If this
32+
# setting is enabled, we also extract discrete memories from message text
33+
# and save them as separate long-term memory records.
34+
enable_discrete_memory_extraction: bool = True
35+
36+
# Topic modeling
37+
topic_model_source: Literal["BERTopic", "LLM"] = "LLM"
38+
topic_model: str = (
39+
"MaartenGr/BERTopic_Wikipedia" # Use an LLM model name here if using LLM
40+
)
3541
enable_topic_extraction: bool = True
36-
enable_ner: bool = True
3742
top_k_topics: int = 3
3843

44+
# Used for extracting entities from text
45+
ner_model: str = "dbmdz/bert-large-cased-finetuned-conll03-english"
46+
enable_ner: bool = True
47+
3948
# RedisVL Settings
4049
redisvl_distance_metric: str = "COSINE"
4150
redisvl_vector_dimensions: str = "1536"

agent_memory_server/extraction.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ async def extract_topics_llm(
146146
return topics
147147

148148

149-
def extract_topics_ner(text: str, num_topics: int | None = None) -> list[str]:
149+
def extract_topics_bertopic(text: str, num_topics: int | None = None) -> list[str]:
150150
"""
151151
Extract topics from text using the BERTopic model.
152152
@@ -193,12 +193,8 @@ async def handle_extraction(text: str) -> tuple[list[str], list[str]]:
193193
# Extract topics if enabled
194194
topics = []
195195
if settings.enable_topic_extraction:
196-
# Check if the topic_model_source setting exists and use appropriate function
197-
if (
198-
hasattr(settings, "topic_model_source")
199-
and settings.topic_model_source == "NER"
200-
):
201-
topics = extract_topics_ner(text)
196+
if settings.topic_model_source == "BERTopic":
197+
topics = extract_topics_bertopic(text)
202198
else:
203199
topics = await extract_topics_llm(text)
204200

@@ -263,7 +259,10 @@ async def handle_extraction(text: str) -> tuple[list[str], list[str]]:
263259
"""
264260

265261

266-
async def extract_discrete_memories(redis: Redis | None = None):
262+
async def extract_discrete_memories(
263+
redis: Redis | None = None,
264+
deduplicate: bool = True,
265+
):
267266
"""
268267
Extract episodic and semantic memories from text using an LLM.
269268
"""
@@ -345,5 +344,5 @@ async def extract_discrete_memories(redis: Redis | None = None):
345344

346345
await index_long_term_memories(
347346
long_term_memories,
348-
deduplicate=True,
347+
deduplicate=deduplicate,
349348
)

agent_memory_server/long_term_memory.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -500,8 +500,6 @@ async def index_long_term_memories(
500500
memories: list[LongTermMemory],
501501
redis_client: Redis | None = None,
502502
deduplicate: bool = False,
503-
deduplicate_hash: bool = True,
504-
deduplicate_semantic: bool = True,
505503
vector_distance_threshold: float = 0.12,
506504
llm_client: Any = None,
507505
) -> None:
@@ -612,7 +610,14 @@ async def index_long_term_memories(
612610
await pipe.execute()
613611

614612
logger.info(f"Indexed {len(processed_memories)} memories")
615-
await background_tasks.add_task(extract_discrete_memories)
613+
if settings.enable_discrete_memory_extraction:
614+
# Extract discrete memories from the indexed messages and persist
615+
# them as separate long-term memory records. This process also
616+
# runs deduplication if requested.
617+
await background_tasks.add_task(
618+
extract_discrete_memories,
619+
deduplicate=deduplicate,
620+
)
616621

617622

618623
async def search_long_term_memories(

tests/test_extraction.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from agent_memory_server.config import settings
77
from agent_memory_server.extraction import (
88
extract_entities,
9-
extract_topics_ner,
9+
extract_topics_bertopic,
1010
handle_extraction,
1111
)
1212

@@ -45,7 +45,7 @@ async def test_extract_topics_success(self, mock_get_topic_model, mock_bertopic)
4545
mock_get_topic_model.return_value = mock_bertopic
4646
text = "Discussion about AI technology and business"
4747

48-
topics = extract_topics_ner(text)
48+
topics = extract_topics_bertopic(text)
4949

5050
assert set(topics) == {"technology", "business"}
5151
mock_bertopic.transform.assert_called_once_with([text])
@@ -58,7 +58,7 @@ async def test_extract_topics_no_valid_topics(
5858
mock_bertopic.transform.return_value = (np.array([-1]), np.array([0.0]))
5959
mock_get_topic_model.return_value = mock_bertopic
6060

61-
topics = extract_topics_ner("Test message")
61+
topics = extract_topics_bertopic("Test message")
6262

6363
assert topics == []
6464
mock_bertopic.transform.assert_called_once()

0 commit comments

Comments
 (0)