Skip to content

Commit e567144

Browse files
Lili Maliangchg
andcommitted
Add mem0_memory Support for PostgreSQL
Co-authored-by: liangchg <[email protected]>
1 parent 66fe044 commit e567144

File tree

3 files changed

+356
-5
lines changed

3 files changed

+356
-5
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ mem0_memory = [
8282
# Need to be optional as a fix for https://github.com/strands-agents/docs/issues/19
8383
"mem0ai>=0.1.99,<1.0.0",
8484
"opensearch-py>=2.8.0,<3.0.0",
85+
"psycopg2-binary",
8586
]
8687
local_chromium_browser = ["nest-asyncio>=1.5.0,<2.0.0", "playwright>=1.42.0,<2.0.0"]
8788
agent_core_browser = [

src/strands_tools/mem0_memory.py

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ class Mem0ServiceClient:
166166
"provider": "opensearch",
167167
"config": {
168168
"port": 443,
169-
"collection_name": "mem0_memories",
169+
"collection_name": os.environ.get("OPENSEARCH_COLLECTION", "mem0"),
170170
"host": os.environ.get("OPENSEARCH_HOST"),
171171
"embedding_model_dims": 1024,
172172
"connection_class": RequestsHttpConnection,
@@ -177,6 +177,33 @@ class Mem0ServiceClient:
177177
},
178178
}
179179

180+
def _get_postgresql_config(self) -> Dict:
181+
"""Get PostgreSQL configuration based on the current provider."""
182+
# Start with the default embedder and llm config
183+
config = {
184+
"embedder": self.DEFAULT_CONFIG["embedder"].copy(),
185+
"llm": self.DEFAULT_CONFIG["llm"].copy(),
186+
}
187+
188+
189+
190+
# Add PostgreSQL vector store configuration
191+
config["vector_store"] = {
192+
"provider": "pgvector",
193+
"config": {
194+
"host": os.environ.get("POSTGRESQL_HOST"),
195+
"port": int(os.environ.get("POSTGRESQL_PORT", 5432)),
196+
"user": os.environ.get("POSTGRESQL_USER"),
197+
"password": os.environ.get("POSTGRESQL_PASSWORD"),
198+
"dbname": os.environ.get("DB_NAME", "postgres"),
199+
"collection_name": os.environ.get("DB_COLLECTION_NAME", "mem0_memories"),
200+
"embedding_model_dims": 1024,
201+
}
202+
}
203+
204+
return config
205+
206+
180207
def __init__(self, config: Optional[Dict] = None):
181208
"""Initialize the Mem0 service client.
182209
@@ -208,6 +235,10 @@ def _initialize_client(self, config: Optional[Dict] = None) -> Any:
208235
logger.debug("Using Neptune Analytics graph backend (Mem0Memory with Neptune Analytics)")
209236
config = self._configure_neptune_analytics_backend(config)
210237

238+
if os.environ.get("POSTGRESQL_HOST"):
239+
logger.info("Using PostgreSQL backend (Mem0Memory with PostgreSQL)")
240+
return self._initialize_postgresql_client(config)
241+
211242
if os.environ.get("OPENSEARCH_HOST"):
212243
logger.debug("Using OpenSearch backend (Mem0Memory with OpenSearch)")
213244
return self._initialize_opensearch_client(config)
@@ -231,6 +262,37 @@ def _configure_neptune_analytics_backend(self, config: Optional[Dict] = None) ->
231262
}
232263
return config
233264

265+
def _initialize_postgresql_client(self, config: Optional[Dict] = None) -> Mem0Memory:
266+
"""Initialize a Mem0 client with PostgreSQL backend.
267+
268+
Args:
269+
config: Optional configuration dictionary to override defaults.
270+
271+
Returns:
272+
An initialized Mem0Memory instance configured for PostgreSQL.
273+
274+
Raises:
275+
ValueError: If required PostgreSQL environment variables are missing.
276+
"""
277+
# Validate required environment variables
278+
required_vars = ["POSTGRESQL_HOST", "POSTGRESQL_USER", "POSTGRESQL_PASSWORD"]
279+
missing_vars = [var for var in required_vars if not os.environ.get(var)]
280+
if missing_vars:
281+
raise ValueError(f"Missing required PostgreSQL environment variables: {', '.join(missing_vars)}")
282+
283+
# Get PostgreSQL configuration
284+
pg_config = self._get_postgresql_config()
285+
286+
# Validate OpenAI API key if using OpenAI
287+
provider = os.environ.get("MEM0_LLM_PROVIDER", "aws_bedrock")
288+
if provider == "openai" and not os.environ.get("OPENAI_API_KEY"):
289+
raise ValueError("OPENAI_API_KEY environment variable is required when using OpenAI provider")
290+
291+
# Merge with user-provided config if any
292+
merged_config = self._merge_configs(pg_config, config)
293+
294+
return Mem0Memory.from_config(config_dict=merged_config)
295+
234296
def _initialize_opensearch_client(self, config: Optional[Dict] = None) -> Mem0Memory:
235297
"""Initialize a Mem0 client with OpenSearch backend.
236298
@@ -296,12 +358,24 @@ def _merge_config(self, config: Optional[Dict] = None) -> Dict:
296358
Returns:
297359
A merged configuration dictionary.
298360
"""
299-
merged_config = self.DEFAULT_CONFIG.copy()
300-
if not config:
361+
return self._merge_configs(self.DEFAULT_CONFIG, config)
362+
363+
def _merge_configs(self, base_config: Dict, override_config: Optional[Dict] = None) -> Dict:
364+
"""Merge two configuration dictionaries.
365+
366+
Args:
367+
base_config: Base configuration dictionary
368+
override_config: Optional configuration to merge into base
369+
370+
Returns:
371+
A merged configuration dictionary.
372+
"""
373+
merged_config = base_config.copy()
374+
if not override_config:
301375
return merged_config
302376

303-
# Deep merge the configs
304-
for key, value in config.items():
377+
# Merge the configs
378+
for key, value in override_config.items():
305379
if key in merged_config and isinstance(value, dict) and isinstance(merged_config[key], dict):
306380
merged_config[key].update(value)
307381
else:

0 commit comments

Comments
 (0)