Skip to content

Commit f7333e9

Browse files
Lili Maliangchg
andcommitted
Add mem0_memory Support for PostgreSQL
Co-authored-by: liangchg <[email protected]>
1 parent 55354a1 commit f7333e9

File tree

3 files changed

+355
-4
lines changed

3 files changed

+355
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ mem0_memory = [
8282
# Need to be optional as a fix for https://github.com/strands-agents/docs/issues/19
8383
"mem0ai>=0.1.99,<1.0.0",
8484
"opensearch-py>=2.8.0,<3.0.0",
85+
"psycopg2-binary",
8586
]
8687
local_chromium_browser = ["nest-asyncio>=1.5.0,<2.0.0", "playwright>=1.42.0,<2.0.0"]
8788
agent_core_browser = [

src/strands_tools/mem0_memory.py

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,33 @@ class Mem0ServiceClient:
177177
},
178178
}
179179

180+
def _get_postgresql_config(self) -> Dict:
181+
"""Get PostgreSQL configuration based on the current provider."""
182+
# Start with the default embedder and llm config
183+
config = {
184+
"embedder": self.DEFAULT_CONFIG["embedder"].copy(),
185+
"llm": self.DEFAULT_CONFIG["llm"].copy(),
186+
}
187+
188+
189+
190+
# Add PostgreSQL vector store configuration
191+
config["vector_store"] = {
192+
"provider": "pgvector",
193+
"config": {
194+
"host": os.environ.get("POSTGRESQL_HOST"),
195+
"port": int(os.environ.get("POSTGRESQL_PORT", 5432)),
196+
"user": os.environ.get("POSTGRESQL_USER"),
197+
"password": os.environ.get("POSTGRESQL_PASSWORD"),
198+
"dbname": os.environ.get("DB_NAME", "postgres"),
199+
"collection_name": os.environ.get("DB_COLLECTION_NAME", "mem0_memories"),
200+
"embedding_model_dims": 1024,
201+
}
202+
}
203+
204+
return config
205+
206+
180207
def __init__(self, config: Optional[Dict] = None):
181208
"""Initialize the Mem0 service client.
182209
@@ -208,6 +235,10 @@ def _initialize_client(self, config: Optional[Dict] = None) -> Any:
208235
logger.debug("Using Neptune Analytics graph backend (Mem0Memory with Neptune Analytics)")
209236
config = self._configure_neptune_analytics_backend(config)
210237

238+
if os.environ.get("POSTGRESQL_HOST"):
239+
logger.info("Using PostgreSQL backend (Mem0Memory with PostgreSQL)")
240+
return self._initialize_postgresql_client(config)
241+
211242
if os.environ.get("OPENSEARCH_HOST"):
212243
logger.debug("Using OpenSearch backend (Mem0Memory with OpenSearch)")
213244
return self._initialize_opensearch_client(config)
@@ -231,6 +262,37 @@ def _configure_neptune_analytics_backend(self, config: Optional[Dict] = None) ->
231262
}
232263
return config
233264

265+
def _initialize_postgresql_client(self, config: Optional[Dict] = None) -> Mem0Memory:
266+
"""Initialize a Mem0 client with PostgreSQL backend.
267+
268+
Args:
269+
config: Optional configuration dictionary to override defaults.
270+
271+
Returns:
272+
An initialized Mem0Memory instance configured for PostgreSQL.
273+
274+
Raises:
275+
ValueError: If required PostgreSQL environment variables are missing.
276+
"""
277+
# Validate required environment variables
278+
required_vars = ["POSTGRESQL_HOST", "POSTGRESQL_USER", "POSTGRESQL_PASSWORD"]
279+
missing_vars = [var for var in required_vars if not os.environ.get(var)]
280+
if missing_vars:
281+
raise ValueError(f"Missing required PostgreSQL environment variables: {', '.join(missing_vars)}")
282+
283+
# Get PostgreSQL configuration
284+
pg_config = self._get_postgresql_config()
285+
286+
# Validate OpenAI API key if using OpenAI
287+
provider = os.environ.get("MEM0_LLM_PROVIDER", "aws_bedrock")
288+
if provider == "openai" and not os.environ.get("OPENAI_API_KEY"):
289+
raise ValueError("OPENAI_API_KEY environment variable is required when using OpenAI provider")
290+
291+
# Merge with user-provided config if any
292+
merged_config = self._merge_configs(pg_config, config)
293+
294+
return Mem0Memory.from_config(config_dict=merged_config)
295+
234296
def _initialize_opensearch_client(self, config: Optional[Dict] = None) -> Mem0Memory:
235297
"""Initialize a Mem0 client with OpenSearch backend.
236298
@@ -296,12 +358,24 @@ def _merge_config(self, config: Optional[Dict] = None) -> Dict:
296358
Returns:
297359
A merged configuration dictionary.
298360
"""
299-
merged_config = self.DEFAULT_CONFIG.copy()
300-
if not config:
361+
return self._merge_configs(self.DEFAULT_CONFIG, config)
362+
363+
def _merge_configs(self, base_config: Dict, override_config: Optional[Dict] = None) -> Dict:
364+
"""Merge two configuration dictionaries.
365+
366+
Args:
367+
base_config: Base configuration dictionary
368+
override_config: Optional configuration to merge into base
369+
370+
Returns:
371+
A merged configuration dictionary.
372+
"""
373+
merged_config = base_config.copy()
374+
if not override_config:
301375
return merged_config
302376

303-
# Deep merge the configs
304-
for key, value in config.items():
377+
# Merge the configs
378+
for key, value in override_config.items():
305379
if key in merged_config and isinstance(value, dict) and isinstance(merged_config[key], dict):
306380
merged_config[key].update(value)
307381
else:

tests/test_mem0.py

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,3 +523,279 @@ def test_faiss_client(mock_mem0_memory, mock_tool):
523523
# Assertions
524524
assert result["status"] == "success"
525525
assert "Test memory content" in str(result["content"][0]["text"])
526+
527+
528+
@patch.dict(
529+
os.environ,
530+
{
531+
"POSTGRESQL_HOST": "test-cluster.cluster-abc123.us-west-2.rds.amazonaws.com",
532+
"POSTGRESQL_USER": "test_user",
533+
"POSTGRESQL_PASSWORD": "test_password",
534+
"DB_NAME": "test_db",
535+
"MEM0_LLM_PROVIDER": "openai",
536+
"MEM0_LLM_MODEL": "gpt-4",
537+
"MEM0_EMBEDDER_PROVIDER": "openai",
538+
"MEM0_EMBEDDER_MODEL": "text-embedding-3-large",
539+
"OPENAI_API_KEY": "test-api-key",
540+
},
541+
)
542+
@patch("strands_tools.mem0_memory.Mem0ServiceClient")
543+
def test_postgresql_store_memory(mock_mem0_client, mock_mem0_service_client, mock_tool):
544+
"""Test PostgreSQL store memory functionality."""
545+
# Setup mocks
546+
mock_mem0_client.return_value = mock_mem0_service_client
547+
548+
# Configure the mock_tool
549+
mock_tool.get.side_effect = lambda key, default=None: {
550+
"toolUseId": "test-id",
551+
"input": {
552+
"action": "store",
553+
"content": "Test memory content",
554+
"user_id": "test_user",
555+
"metadata": {"category": "test"},
556+
},
557+
}.get(key, default)
558+
559+
# Mock data
560+
store_response = [
561+
{
562+
"event": "store",
563+
"memory": "Test memory content",
564+
"id": "mem123",
565+
"created_at": "2024-03-20T10:00:00Z",
566+
}
567+
]
568+
569+
# Configure mocks
570+
mock_mem0_service_client.store_memory.return_value = store_response
571+
572+
# Call the memory function
573+
result = mem0_memory.mem0_memory(tool=mock_tool)
574+
575+
# Assertions
576+
assert result["status"] == "success"
577+
assert result["content"][0]["text"] == json.dumps(store_response, indent=2)
578+
579+
580+
@patch.dict(
581+
os.environ,
582+
{
583+
"POSTGRESQL_HOST": "test-cluster.cluster-abc123.us-west-2.rds.amazonaws.com",
584+
"POSTGRESQL_USER": "test_user",
585+
"POSTGRESQL_PASSWORD": "test_password",
586+
"DB_NAME": "test_db",
587+
},
588+
)
589+
@patch("strands_tools.mem0_memory.Mem0ServiceClient")
590+
def test_postgresql_get_memory(mock_mem0_client, mock_mem0_service_client, mock_tool):
591+
"""Test PostgreSQL get memory functionality."""
592+
# Setup mocks
593+
mock_mem0_client.return_value = mock_mem0_service_client
594+
595+
# Configure the mock_tool
596+
mock_tool.get.side_effect = lambda key, default=None: {
597+
"toolUseId": "test-id",
598+
"input": {"action": "get", "memory_id": "mem123"},
599+
}.get(key, default)
600+
601+
# Mock data
602+
get_response = {
603+
"id": "mem123",
604+
"memory": "Test memory content",
605+
"created_at": "2024-03-20T10:00:00Z",
606+
"user_id": "test_user",
607+
"metadata": {"category": "test"},
608+
}
609+
610+
# Configure mocks
611+
mock_mem0_service_client.get_memory.return_value = get_response
612+
613+
# Call the memory function
614+
result = mem0_memory.mem0_memory(tool=mock_tool)
615+
616+
# Assertions
617+
assert result["status"] == "success"
618+
assert isinstance(result["content"], list)
619+
assert len(result["content"]) > 0
620+
assert "text" in result["content"][0]
621+
memory = json.loads(result["content"][0]["text"])
622+
assert memory["id"] == "mem123"
623+
assert memory["memory"] == "Test memory content"
624+
assert memory["user_id"] == "test_user"
625+
assert memory["metadata"] == {"category": "test"}
626+
627+
628+
@patch.dict(
629+
os.environ,
630+
{
631+
"POSTGRESQL_HOST": "test-cluster.cluster-abc123.us-west-2.rds.amazonaws.com",
632+
"POSTGRESQL_USER": "test_user",
633+
"POSTGRESQL_PASSWORD": "test_password",
634+
"DB_NAME": "test_db",
635+
},
636+
)
637+
@patch("strands_tools.mem0_memory.Mem0ServiceClient")
638+
def test_postgresql_list_memories(mock_mem0_client, mock_mem0_service_client, mock_tool):
639+
"""Test PostgreSQL list memories functionality."""
640+
# Setup mocks
641+
mock_mem0_client.return_value = mock_mem0_service_client
642+
643+
# Configure the mock_tool
644+
mock_tool.get.side_effect = lambda key, default=None: {
645+
"toolUseId": "test-id",
646+
"input": {"action": "list", "user_id": "test_user"},
647+
}.get(key, default)
648+
649+
# Mock data for list_memories response
650+
list_response = {
651+
"results": [
652+
{
653+
"id": "mem123",
654+
"memory": "Test memory content",
655+
"created_at": "2024-03-20T10:00:00Z",
656+
"user_id": "test_user",
657+
"metadata": {"category": "test"},
658+
}
659+
]
660+
}
661+
662+
# Configure mocks
663+
mock_mem0_service_client.list_memories.return_value = list_response
664+
665+
# Call the memory function
666+
result = mem0_memory.mem0_memory(tool=mock_tool)
667+
668+
# Assertions
669+
assert result["status"] == "success"
670+
assert isinstance(result["content"], list)
671+
assert len(result["content"]) > 0
672+
assert "text" in result["content"][0]
673+
# Parse the JSON string in text
674+
memories = json.loads(result["content"][0]["text"])
675+
assert isinstance(memories, list)
676+
assert len(memories) > 0
677+
assert "id" in memories[0]
678+
assert memories[0]["id"] == "mem123"
679+
680+
681+
@patch.dict(
682+
os.environ,
683+
{
684+
"POSTGRESQL_HOST": "test-cluster.cluster-abc123.us-west-2.rds.amazonaws.com",
685+
"POSTGRESQL_USER": "test_user",
686+
"POSTGRESQL_PASSWORD": "test_password",
687+
"DB_NAME": "test_db",
688+
},
689+
)
690+
@patch("strands_tools.mem0_memory.Mem0ServiceClient")
691+
def test_postgresql_retrieve_memories(mock_mem0_client, mock_mem0_service_client, mock_tool):
692+
"""Test PostgreSQL retrieve memories functionality."""
693+
# Setup mocks
694+
mock_mem0_client.return_value = mock_mem0_service_client
695+
696+
# Configure the mock_tool
697+
mock_tool.get.side_effect = lambda key, default=None: {
698+
"toolUseId": "test-id",
699+
"input": {"action": "retrieve", "query": "test query", "user_id": "test_user"},
700+
}.get(key, default)
701+
702+
# Mock data for search_memories response
703+
retrieve_response = {
704+
"results": [
705+
{
706+
"id": "mem123",
707+
"memory": "Test memory content",
708+
"score": 0.85,
709+
"created_at": "2024-03-20T10:00:00Z",
710+
"user_id": "test_user",
711+
"metadata": {"category": "test"},
712+
}
713+
]
714+
}
715+
716+
# Configure mocks
717+
mock_mem0_service_client.search_memories.return_value = retrieve_response
718+
719+
# Call the memory function
720+
result = mem0_memory.mem0_memory(tool=mock_tool)
721+
722+
# Assertions
723+
assert result["status"] == "success"
724+
assert isinstance(result["content"], list)
725+
assert len(result["content"]) > 0
726+
assert "text" in result["content"][0]
727+
# Parse the JSON string in text
728+
memories = json.loads(result["content"][0]["text"])
729+
assert isinstance(memories, list)
730+
assert len(memories) > 0
731+
assert "id" in memories[0]
732+
assert memories[0]["id"] == "mem123"
733+
734+
735+
@patch.dict(
736+
os.environ,
737+
{
738+
"POSTGRESQL_HOST": "test-cluster.cluster-abc123.us-west-2.rds.amazonaws.com",
739+
"POSTGRESQL_USER": "test_user",
740+
"POSTGRESQL_PASSWORD": "test_password",
741+
"DB_NAME": "test_db",
742+
"BYPASS_TOOL_CONSENT": "true",
743+
},
744+
)
745+
@patch("strands_tools.mem0_memory.Mem0ServiceClient")
746+
def test_postgresql_delete_memory(mock_mem0_client, mock_mem0_service_client, mock_tool):
747+
"""Test PostgreSQL delete memory functionality with BYPASS_TOOL_CONSENT mode enabled."""
748+
# Setup mocks
749+
mock_mem0_client.return_value = mock_mem0_service_client
750+
751+
# Configure the mock_tool
752+
mock_tool.get.side_effect = lambda key, default=None: {
753+
"toolUseId": "test-id",
754+
"input": {"action": "delete", "memory_id": "mem123"},
755+
}.get(key, default)
756+
757+
# Configure mocks
758+
mock_mem0_service_client.delete_memory.return_value = {"status": "success"}
759+
760+
# Call the memory function
761+
result = mem0_memory.mem0_memory(tool=mock_tool)
762+
763+
# Assertions
764+
assert result["status"] == "success"
765+
assert "Memory mem123 deleted successfully" in str(result["content"][0]["text"])
766+
767+
# Verify correct functions were called
768+
mock_mem0_service_client.delete_memory.assert_called_once()
769+
call_args = mock_mem0_service_client.delete_memory.call_args[0]
770+
assert call_args[0] == "mem123"
771+
772+
773+
@patch.dict(
774+
os.environ,
775+
{
776+
"POSTGRESQL_HOST": "test-cluster.cluster-abc123.us-west-2.rds.amazonaws.com",
777+
"POSTGRESQL_USER": "test_user",
778+
# Missing POSTGRESQL_PASSWORD
779+
"MEM0_LLM_PROVIDER": "openai",
780+
"OPENAI_API_KEY": "test-api-key",
781+
},
782+
)
783+
def test_postgresql_missing_required_vars(mock_tool):
784+
"""Test PostgreSQL client with missing required environment variables."""
785+
# Configure the mock_tool
786+
mock_tool.get.side_effect = lambda key, default=None: {
787+
"toolUseId": "test-id",
788+
"input": {
789+
"action": "store",
790+
"content": "Test memory content",
791+
"user_id": "test_user",
792+
},
793+
}.get(key, default)
794+
795+
# Call the memory function
796+
result = mem0_memory.mem0_memory(tool=mock_tool)
797+
798+
# Assertions
799+
assert result["status"] == "error"
800+
assert "Missing required PostgreSQL environment variables" in str(result["content"][0]["text"])
801+
assert "POSTGRESQL_PASSWORD" in str(result["content"][0]["text"])

0 commit comments

Comments
 (0)