Skip to content

Commit a4ee506

Browse files
author
Lili Ma
committed
Add mem0_memory Support for PostgreSQL
1 parent 55354a1 commit a4ee506

File tree

3 files changed

+353
-4
lines changed

3 files changed

+353
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ mem0_memory = [
8282
# Need to be optional as a fix for https://github.com/strands-agents/docs/issues/19
8383
"mem0ai>=0.1.99,<1.0.0",
8484
"opensearch-py>=2.8.0,<3.0.0",
85+
"psycopg2-binary",
8586
]
8687
local_chromium_browser = ["nest-asyncio>=1.5.0,<2.0.0", "playwright>=1.42.0,<2.0.0"]
8788
agent_core_browser = [

src/strands_tools/mem0_memory.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,30 @@ class Mem0ServiceClient:
177177
},
178178
}
179179

180+
def _get_postgresql_config(self) -> Dict:
181+
"""Get PostgreSQL configuration based on the current provider."""
182+
# Start with the default embedder and llm config
183+
config = {
184+
"embedder": self.DEFAULT_CONFIG["embedder"].copy(),
185+
"llm": self.DEFAULT_CONFIG["llm"].copy(),
186+
}
187+
188+
# Add PostgreSQL vector store configuration
189+
config["vector_store"] = {
190+
"provider": "pgvector",
191+
"config": {
192+
"host": os.environ.get("POSTGRESQL_HOST"),
193+
"port": int(os.environ.get("POSTGRESQL_PORT", 5432)),
194+
"user": os.environ.get("POSTGRESQL_USER"),
195+
"password": os.environ.get("POSTGRESQL_PASSWORD"),
196+
"dbname": os.environ.get("DB_NAME", "postgres"),
197+
"collection_name": os.environ.get("DB_COLLECTION_NAME", "mem0_memories"),
198+
"embedding_model_dims": 1024,
199+
},
200+
}
201+
202+
return config
203+
180204
def __init__(self, config: Optional[Dict] = None):
181205
"""Initialize the Mem0 service client.
182206
@@ -208,6 +232,10 @@ def _initialize_client(self, config: Optional[Dict] = None) -> Any:
208232
logger.debug("Using Neptune Analytics graph backend (Mem0Memory with Neptune Analytics)")
209233
config = self._configure_neptune_analytics_backend(config)
210234

235+
if os.environ.get("POSTGRESQL_HOST"):
236+
logger.info("Using PostgreSQL backend (Mem0Memory with PostgreSQL)")
237+
return self._initialize_postgresql_client(config)
238+
211239
if os.environ.get("OPENSEARCH_HOST"):
212240
logger.debug("Using OpenSearch backend (Mem0Memory with OpenSearch)")
213241
return self._initialize_opensearch_client(config)
@@ -231,6 +259,37 @@ def _configure_neptune_analytics_backend(self, config: Optional[Dict] = None) ->
231259
}
232260
return config
233261

262+
def _initialize_postgresql_client(self, config: Optional[Dict] = None) -> Mem0Memory:
263+
"""Initialize a Mem0 client with PostgreSQL backend.
264+
265+
Args:
266+
config: Optional configuration dictionary to override defaults.
267+
268+
Returns:
269+
An initialized Mem0Memory instance configured for PostgreSQL.
270+
271+
Raises:
272+
ValueError: If required PostgreSQL environment variables are missing.
273+
"""
274+
# Validate required environment variables
275+
required_vars = ["POSTGRESQL_HOST", "POSTGRESQL_USER", "POSTGRESQL_PASSWORD"]
276+
missing_vars = [var for var in required_vars if not os.environ.get(var)]
277+
if missing_vars:
278+
raise ValueError(f"Missing required PostgreSQL environment variables: {', '.join(missing_vars)}")
279+
280+
# Get PostgreSQL configuration
281+
pg_config = self._get_postgresql_config()
282+
283+
# Validate OpenAI API key if using OpenAI
284+
provider = os.environ.get("MEM0_LLM_PROVIDER", "aws_bedrock")
285+
if provider == "openai" and not os.environ.get("OPENAI_API_KEY"):
286+
raise ValueError("OPENAI_API_KEY environment variable is required when using OpenAI provider")
287+
288+
# Merge with user-provided config if any
289+
merged_config = self._merge_configs(pg_config, config)
290+
291+
return Mem0Memory.from_config(config_dict=merged_config)
292+
234293
def _initialize_opensearch_client(self, config: Optional[Dict] = None) -> Mem0Memory:
235294
"""Initialize a Mem0 client with OpenSearch backend.
236295
@@ -296,12 +355,24 @@ def _merge_config(self, config: Optional[Dict] = None) -> Dict:
296355
Returns:
297356
A merged configuration dictionary.
298357
"""
299-
merged_config = self.DEFAULT_CONFIG.copy()
300-
if not config:
358+
return self._merge_configs(self.DEFAULT_CONFIG, config)
359+
360+
def _merge_configs(self, base_config: Dict, override_config: Optional[Dict] = None) -> Dict:
361+
"""Merge two configuration dictionaries.
362+
363+
Args:
364+
base_config: Base configuration dictionary
365+
override_config: Optional configuration to merge into base
366+
367+
Returns:
368+
A merged configuration dictionary.
369+
"""
370+
merged_config = base_config.copy()
371+
if not override_config:
301372
return merged_config
302373

303-
# Deep merge the configs
304-
for key, value in config.items():
374+
# Merge the configs
375+
for key, value in override_config.items():
305376
if key in merged_config and isinstance(value, dict) and isinstance(merged_config[key], dict):
306377
merged_config[key].update(value)
307378
else:

0 commit comments

Comments
 (0)