Skip to content

Commit 8e2c5f0

Browse files
Support linear memory in RedisMemory (#6972)
Co-authored-by: Eric Zhu <[email protected]>
1 parent 7fbf8ab commit 8e2c5f0

File tree

3 files changed

+207
-46
lines changed

3 files changed

+207
-46
lines changed

python/packages/autogen-ext/src/autogen_ext/memory/redis/_redis_memory.py

Lines changed: 64 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111

1212
try:
1313
from redis import Redis
14-
from redisvl.extensions.message_history import SemanticMessageHistory
14+
from redisvl.extensions.message_history import MessageHistory, SemanticMessageHistory
1515
from redisvl.utils.utils import deserialize, serialize
16+
from redisvl.utils.vectorize import HFTextVectorizer
1617
except ImportError as e:
1718
raise ImportError("To use Redis Memory RedisVL must be installed. Run `pip install autogen-ext[redisvl]`") from e
1819

@@ -29,24 +30,25 @@ class RedisMemoryConfig(BaseModel):
2930
redis_url: str = Field(default="redis://localhost:6379", description="url of the Redis instance")
3031
index_name: str = Field(default="chat_history", description="Name of the Redis collection")
3132
prefix: str = Field(default="memory", description="prefix of the Redis collection")
33+
sequential: bool = Field(
34+
default=False, description="ignore semantic similarity and simply return memories in sequential order"
35+
)
3236
distance_metric: Literal["cosine", "ip", "l2"] = "cosine"
3337
algorithm: Literal["flat", "hnsw"] = "flat"
3438
top_k: int = Field(default=10, description="Number of results to return in queries")
3539
datatype: Literal["uint8", "int8", "float16", "float32", "float64", "bfloat16"] = "float32"
3640
distance_threshold: float = Field(default=0.7, description="Minimum similarity score threshold")
37-
model_name: str | None = Field(
38-
default="sentence-transformers/all-mpnet-base-v2", description="Embedding model name"
39-
)
41+
model_name: str = Field(default="sentence-transformers/all-mpnet-base-v2", description="Embedding model name")
4042

4143

4244
class RedisMemory(Memory, Component[RedisMemoryConfig]):
4345
"""
4446
Store and retrieve memory using vector similarity search powered by RedisVL.
4547
4648
`RedisMemory` provides a vector-based memory implementation that uses RedisVL for storing and
47-
retrieving content based on semantic similarity. It enhances agents with the ability to recall
48-
contextually relevant information during conversations by leveraging vector embeddings to find
49-
similar content.
49+
retrieving content based on semantic similarity or sequential order. It enhances agents with the
50+
ability to recall relevant information during conversations by leveraging vector embeddings to
51+
find similar content.
5052
5153
This implementation requires the RedisVL extra to be installed. Install with:
5254
@@ -175,7 +177,19 @@ def __init__(self, config: RedisMemoryConfig | None = None) -> None:
175177
self.config = config or RedisMemoryConfig()
176178
client = Redis.from_url(url=self.config.redis_url) # type: ignore[reportUknownMemberType]
177179

178-
self.message_history = SemanticMessageHistory(name=self.config.index_name, redis_client=client)
180+
if self.config.sequential:
181+
self.message_history = MessageHistory(
182+
name=self.config.index_name, prefix=self.config.prefix, redis_client=client
183+
)
184+
else:
185+
vectorizer = HFTextVectorizer(model=self.config.model_name, dtype=self.config.datatype)
186+
self.message_history = SemanticMessageHistory(
187+
name=self.config.index_name,
188+
prefix=self.config.prefix,
189+
vectorizer=vectorizer,
190+
distance_threshold=self.config.distance_threshold,
191+
redis_client=client,
192+
)
179193

180194
async def update_context(
181195
self,
@@ -203,7 +217,7 @@ async def update_context(
203217
else:
204218
last_message = ""
205219

206-
query_results = await self.query(last_message)
220+
query_results = await self.query(last_message, sequential=self.config.sequential)
207221

208222
stringified_messages = "\n\n".join([str(m.content) for m in query_results.results])
209223

@@ -216,10 +230,10 @@ async def add(self, content: MemoryContent, cancellation_token: CancellationToke
216230
217231
.. note::
218232
219-
To perform semantic search over stored memories RedisMemory creates a vector embedding
220-
from the content field of a MemoryContent object. This content is assumed to be text,
221-
JSON, or Markdown, and is passed to the vector embedding model specified in
222-
RedisMemoryConfig.
233+
If RedisMemoryConfig is not set to 'sequential', to perform semantic search over stored
234+
memories RedisMemory creates a vector embedding from the content field of a
235+
MemoryContent object. This content is assumed to be text, JSON, or Markdown, and is
236+
passed to the vector embedding model specified in RedisMemoryConfig.
223237
224238
Args:
225239
content (MemoryContent): The memory content to store within Redis.
@@ -241,7 +255,7 @@ async def add(self, content: MemoryContent, cancellation_token: CancellationToke
241255
metadata = {"mime_type": mime_type}
242256
metadata.update(content.metadata if content.metadata else {})
243257
self.message_history.add_message(
244-
{"role": "user", "content": memory_content, "tool_call_id": serialize(metadata)} # type: ignore[reportArgumentType]
258+
{"role": "user", "content": memory_content, "metadata": serialize(metadata)} # type: ignore[reportArgumentType]
245259
)
246260

247261
async def query(
@@ -258,6 +272,7 @@ async def query(
258272
top_k (int): The maximum number of relevant memories to include. Defaults to 10.
259273
distance_threshold (float): The maximum distance in vector space to consider a memory
260274
semantically similar when performining cosine similarity search. Defaults to 0.7.
275+
sequential (bool): Ignore semantic similarity and return the top_k most recent memories.
261276
262277
Args:
263278
query (str | MemoryContent): query to perform vector similarity search with. If a
@@ -270,34 +285,46 @@ async def query(
270285
Returns:
271286
memoryQueryResult: Object containing memories relevant to the provided query.
272287
"""
273-
# get the query string, or raise an error for unsupported MemoryContent types
274-
if isinstance(query, str):
275-
prompt = query
276-
elif isinstance(query, MemoryContent):
277-
if query.mime_type in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN):
278-
prompt = str(query.content)
279-
elif query.mime_type == MemoryMimeType.JSON:
280-
prompt = serialize(query.content)
281-
else:
282-
raise NotImplementedError(
283-
f"Error: {query.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, MemoryMimeType.MARKDOWN are currently supported."
284-
)
285-
else:
286-
raise TypeError("'query' must be either a string or MemoryContent")
287-
288288
top_k = kwargs.pop("top_k", self.config.top_k)
289289
distance_threshold = kwargs.pop("distance_threshold", self.config.distance_threshold)
290290

291-
results = self.message_history.get_relevant(
292-
prompt=prompt, # type: ignore[reportArgumentType]
293-
top_k=top_k,
294-
distance_threshold=distance_threshold,
295-
raw=False,
296-
)
291+
# if sequential memory is requested skip prompt creation
292+
sequential = bool(kwargs.pop("sequential", self.config.sequential))
293+
if self.config.sequential and not sequential:
294+
raise ValueError(
295+
"Non-sequential queries cannot be run with an underlying sequential RedisMemory. Set sequential=False in RedisMemoryConfig to enable semantic memory querying."
296+
)
297+
elif sequential or self.config.sequential:
298+
results = self.message_history.get_recent(
299+
top_k=top_k,
300+
raw=False,
301+
)
302+
else:
303+
# get the query string, or raise an error for unsupported MemoryContent types
304+
if isinstance(query, str):
305+
prompt = query
306+
elif isinstance(query, MemoryContent):
307+
if query.mime_type in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN):
308+
prompt = str(query.content)
309+
elif query.mime_type == MemoryMimeType.JSON:
310+
prompt = serialize(query.content)
311+
else:
312+
raise NotImplementedError(
313+
f"Error: {query.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, MemoryMimeType.MARKDOWN are currently supported."
314+
)
315+
else:
316+
raise TypeError("'query' must be either a string or MemoryContent")
317+
318+
results = self.message_history.get_relevant( # type: ignore
319+
prompt=prompt, # type: ignore[reportArgumentType]
320+
top_k=top_k,
321+
distance_threshold=distance_threshold,
322+
raw=False,
323+
)
297324

298325
memories: List[MemoryContent] = []
299-
for result in results:
300-
metadata = deserialize(result["tool_call_id"]) # type: ignore[reportArgumentType]
326+
for result in results: # type: ignore[reportUnkownVariableType]
327+
metadata = deserialize(result["metadata"]) # type: ignore[reportArgumentType]
301328
mime_type = MemoryMimeType(metadata.pop("mime_type"))
302329
if mime_type in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN):
303330
memory_content = result["content"] # type: ignore[reportArgumentType]

python/packages/autogen-ext/tests/memory/test_redis_memory.py

Lines changed: 140 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ async def test_redis_memory_query_with_mock() -> None:
3636
memory = RedisMemory(config=config)
3737

3838
mock_history.get_relevant.return_value = [
39-
{"content": "test content", "tool_call_id": '{"foo": "bar", "mime_type": "text/plain"}'}
39+
{"content": "test content", "metadata": '{"foo": "bar", "mime_type": "text/plain"}'}
4040
]
4141
result = await memory.query("test")
4242
assert len(result.results) == 1
@@ -86,13 +86,26 @@ def semantic_config() -> RedisMemoryConfig:
8686
return RedisMemoryConfig(top_k=5, distance_threshold=0.5, model_name="sentence-transformers/all-mpnet-base-v2")
8787

8888

89+
@pytest.fixture
90+
def sequential_config() -> RedisMemoryConfig:
91+
"""Create base configuration using semantic memory."""
92+
return RedisMemoryConfig(top_k=5, sequential=True)
93+
94+
8995
@pytest_asyncio.fixture # type: ignore[reportUntypedFunctionDecorator]
9096
async def semantic_memory(semantic_config: RedisMemoryConfig) -> AsyncGenerator[RedisMemory]:
9197
memory = RedisMemory(semantic_config)
9298
yield memory
9399
await memory.close()
94100

95101

102+
@pytest_asyncio.fixture # type: ignore[reportUntypedFunctionDecorator]
103+
async def sequential_memory(sequential_config: RedisMemoryConfig) -> AsyncGenerator[RedisMemory]:
104+
memory = RedisMemory(sequential_config)
105+
yield memory
106+
await memory.close()
107+
108+
96109
## UNIT TESTS ##
97110
def test_memory_config() -> None:
98111
default_config = RedisMemoryConfig()
@@ -104,6 +117,7 @@ def test_memory_config() -> None:
104117
assert default_config.top_k == 10
105118
assert default_config.distance_threshold == 0.7
106119
assert default_config.model_name == "sentence-transformers/all-mpnet-base-v2"
120+
assert not default_config.sequential
107121

108122
# test we can specify each of these values
109123
url = "rediss://localhost:7010"
@@ -144,14 +158,36 @@ def test_memory_config() -> None:
144158

145159
@pytest.mark.asyncio
146160
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
147-
async def test_create_semantic_memory() -> None:
148-
config = RedisMemoryConfig(index_name="semantic_agent")
161+
@pytest.mark.parametrize("sequential", [True, False])
162+
async def test_create_memory(sequential: bool) -> None:
163+
config = RedisMemoryConfig(index_name="semantic_agent", sequential=sequential)
149164
memory = RedisMemory(config=config)
150165

151166
assert memory.message_history is not None
152167
await memory.close()
153168

154169

170+
@pytest.mark.asyncio
171+
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
172+
async def test_specify_vectorizer() -> None:
173+
config = RedisMemoryConfig(index_name="semantic_agent", model_name="redis/langcache-embed-v1")
174+
memory = RedisMemory(config=config)
175+
assert memory.message_history._vectorizer.dims == 768 # type: ignore[reportPrivateUsage]
176+
await memory.close()
177+
178+
config = RedisMemoryConfig(
179+
index_name="semantic_agent", model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
180+
)
181+
memory = RedisMemory(config=config)
182+
assert memory.message_history._vectorizer.dims == 384 # type: ignore[reportPrivateUsage]
183+
await memory.close()
184+
185+
# throw an error if a non-existant model name is passed
186+
config = RedisMemoryConfig(index_name="semantic_agent", model_name="not-a-real-model")
187+
with pytest.raises(OSError):
188+
memory = RedisMemory(config=config)
189+
190+
155191
@pytest.mark.asyncio
156192
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
157193
async def test_update_context(semantic_memory: RedisMemory) -> None:
@@ -223,7 +259,7 @@ async def test_update_context(semantic_memory: RedisMemory) -> None:
223259

224260
@pytest.mark.asyncio
225261
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
226-
async def test_add_and_query(semantic_memory: RedisMemory) -> None:
262+
async def test_add_and_query_with_string(semantic_memory: RedisMemory) -> None:
227263
content_1 = MemoryContent(
228264
content="I enjoy fruits like apples, oranges, and bananas.", mime_type=MemoryMimeType.TEXT, metadata={}
229265
)
@@ -251,6 +287,38 @@ async def test_add_and_query(semantic_memory: RedisMemory) -> None:
251287
assert memories.results[1].metadata == {"description": "additional info"}
252288

253289

290+
@pytest.mark.asyncio
291+
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
292+
async def test_add_and_query_with_memory_content(semantic_memory: RedisMemory) -> None:
293+
content_1 = MemoryContent(
294+
content="I enjoy fruits like apples, oranges, and bananas.", mime_type=MemoryMimeType.TEXT, metadata={}
295+
)
296+
await semantic_memory.add(content_1)
297+
298+
# find matches with a similar query
299+
memories = await semantic_memory.query(MemoryContent(content="Fruits that I like.", mime_type=MemoryMimeType.TEXT))
300+
assert len(memories.results) == 1
301+
302+
# don't return anything for dissimilar queries
303+
no_memories = await semantic_memory.query(
304+
MemoryContent(content="The king of England", mime_type=MemoryMimeType.TEXT)
305+
)
306+
assert len(no_memories.results) == 0
307+
308+
# match multiple relevant memories
309+
content_2 = MemoryContent(
310+
content="I also like mangos and pineapples.",
311+
mime_type=MemoryMimeType.TEXT,
312+
metadata={"description": "additional info"},
313+
)
314+
await semantic_memory.add(content_2)
315+
316+
memories = await semantic_memory.query(MemoryContent(content="Fruits that I like.", mime_type=MemoryMimeType.TEXT))
317+
assert len(memories.results) == 2
318+
assert memories.results[0].metadata == {}
319+
assert memories.results[1].metadata == {"description": "additional info"}
320+
321+
254322
@pytest.mark.asyncio
255323
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
256324
async def test_clear(semantic_memory: RedisMemory) -> None:
@@ -283,9 +351,16 @@ async def test_close(semantic_config: RedisMemoryConfig) -> None:
283351
## INTEGRATION TESTS ##
284352
@pytest.mark.asyncio
285353
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
286-
async def test_basic_workflow(semantic_config: RedisMemoryConfig) -> None:
354+
@pytest.mark.parametrize("config_type", ["sequential", "semantic"])
355+
async def test_basic_workflow(config_type: str) -> None:
287356
"""Test basic memory operations with semantic memory."""
288-
memory = RedisMemory(config=semantic_config)
357+
if config_type == "sequential":
358+
config = RedisMemoryConfig(top_k=5, sequential=True)
359+
else:
360+
config = RedisMemoryConfig(
361+
top_k=5, distance_threshold=0.5, model_name="sentence-transformers/all-mpnet-base-v2"
362+
)
363+
memory = RedisMemory(config=config)
289364
await memory.clear()
290365

291366
await memory.add(
@@ -318,6 +393,11 @@ async def test_text_memory_type(semantic_memory: RedisMemory) -> None:
318393
assert len(results.results) > 0
319394
assert any("Simple text content" in str(r.content) for r in results.results)
320395

396+
# Query for text content with a MemoryContent object
397+
results = await semantic_memory.query(MemoryContent(content="simple text content", mime_type=MemoryMimeType.TEXT))
398+
assert len(results.results) > 0
399+
assert any("Simple text content" in str(r.content) for r in results.results)
400+
321401

322402
@pytest.mark.asyncio
323403
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
@@ -419,3 +499,57 @@ async def test_query_arguments(semantic_memory: RedisMemory) -> None:
419499
# limit search to only close matches
420500
results = await semantic_memory.query("my favorite fruit are what?", distance_threshold=0.2)
421501
assert len(results.results) == 1
502+
503+
# get memories based on recency instead of relevance
504+
results = await semantic_memory.query("fast sports cars", sequential=True)
505+
assert len(results.results) == 3
506+
507+
# setting 'sequential' to False results in default behaviour
508+
results = await semantic_memory.query("my favorite fruit are what?", sequential=False)
509+
assert len(results.results) == 3
510+
511+
512+
@pytest.mark.asyncio
513+
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
514+
async def test_sequential_memory_workflow(sequential_memory: RedisMemory) -> None:
515+
await sequential_memory.clear()
516+
517+
await sequential_memory.add(MemoryContent(content="my favorite fruit are apples", mime_type=MemoryMimeType.TEXT))
518+
await sequential_memory.add(
519+
MemoryContent(
520+
content="I read the encyclopedia britanica and my favorite section was on the Napoleonic Wars.",
521+
mime_type=MemoryMimeType.TEXT,
522+
)
523+
)
524+
await sequential_memory.add(
525+
MemoryContent(content="Sharks have no idea that camels exist.", mime_type=MemoryMimeType.TEXT)
526+
)
527+
await sequential_memory.add(
528+
MemoryContent(
529+
content="Python is a popular programming language used for machine learning and AI applications.",
530+
mime_type=MemoryMimeType.TEXT,
531+
)
532+
)
533+
await sequential_memory.add(
534+
MemoryContent(content="Fifth random and unrelated sentence", mime_type=MemoryMimeType.TEXT)
535+
)
536+
537+
# default search returns last 5 memories
538+
results = await sequential_memory.query("what fruits do I like?")
539+
assert len(results.results) == 5
540+
541+
# limit search to 2 results
542+
results = await sequential_memory.query("what fruits do I like?", top_k=2)
543+
assert len(results.results) == 2
544+
545+
# sequential memory does not consider semantic similarity
546+
results = await sequential_memory.query("How do I make peanut butter sandwiches?")
547+
assert len(results.results) == 5
548+
549+
# seting 'sequential' to True in query method is redundant
550+
results = await sequential_memory.query("fast sports cars", sequential=True)
551+
assert len(results.results) == 5
552+
553+
# setting 'sequential' to False with a Sequential memory object raises an error
554+
with pytest.raises(ValueError):
555+
_ = await sequential_memory.query("my favorite fruit are what?", sequential=False)

0 commit comments

Comments
 (0)