diff --git a/mem0/configs/vector_stores/opensearch.py b/mem0/configs/vector_stores/opensearch.py index 2103540ffe..3a24006180 100644 --- a/mem0/configs/vector_stores/opensearch.py +++ b/mem0/configs/vector_stores/opensearch.py @@ -14,6 +14,7 @@ class OpenSearchConfig(BaseModel): verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)") use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)") auto_create_index: bool = Field(True, description="Automatically create index during initialization") + http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4") @model_validator(mode="before") @classmethod @@ -23,7 +24,7 @@ def validate_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]: raise ValueError("Host must be provided for OpenSearch") # Authentication: Either API key or user/password must be provided - if not any([values.get("api_key"), (values.get("user") and values.get("password"))]): + if not any([values.get("api_key"), (values.get("user") and values.get("password")), values.get("http_auth")]): raise ValueError("Either api_key or user/password must be provided for OpenSearch authentication") return values diff --git a/mem0/vector_stores/opensearch.py b/mem0/vector_stores/opensearch.py index 159ec9124c..2a58ac456c 100644 --- a/mem0/vector_stores/opensearch.py +++ b/mem0/vector_stores/opensearch.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional try: - from opensearchpy import OpenSearch + from opensearchpy import OpenSearch, RequestsHttpConnection from opensearchpy.helpers import bulk except ImportError: raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None @@ -28,9 +28,10 @@ def __init__(self, **kwargs): # Initialize OpenSearch client self.client = OpenSearch( hosts=[{"host": config.host, "port": config.port or 9200}], - http_auth=(config.user, config.password) if (config.user and config.password) else None, + http_auth=config.http_auth if config.http_auth else ((config.user, config.password) if (config.user and config.password) else None), use_ssl=config.use_ssl, verify_certs=config.verify_certs, + connection_class=RequestsHttpConnection ) self.collection_name = config.collection_name @@ -43,14 +44,17 @@ def __init__(self, **kwargs): def create_index(self) -> None: """Create OpenSearch index with proper mappings if it doesn't exist.""" index_settings = { - # ToDo change replicas to 1 "settings": { "index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "1s", "knn": True} }, "mappings": { "properties": { "text": {"type": "text"}, - "vector": {"type": "knn_vector", "dimension": self.vector_dim}, + "vector": { + "type": "knn_vector", + "dimension": self.vector_dim, + "method": {"engine": "lucene", "name": "hnsw", "space_type": "cosinesimil"}, + }, "metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}}, } }, diff --git a/tests/vector_stores/test_opensearch.py b/tests/vector_stores/test_opensearch.py index 40d0959ba0..912b660d2a 100644 --- a/tests/vector_stores/test_opensearch.py +++ b/tests/vector_stores/test_opensearch.py @@ -5,7 +5,7 @@ import dotenv try: - from opensearchpy import OpenSearch + from opensearchpy import OpenSearch, AWSV4SignerAuth except ImportError: raise ImportError( "OpenSearch requires extra dependencies. Install with `pip install opensearch-py`" @@ -148,3 +148,29 @@ def test_delete(self): def test_delete_col(self): self.os_db.delete_col() self.client_mock.indices.delete.assert_called_once_with(index="test_collection") + + + def test_init_with_http_auth(self): + mock_credentials = MagicMock() + mock_signer = AWSV4SignerAuth(mock_credentials, "us-east-1", "es") + + with patch('mem0.vector_stores.opensearch.OpenSearch') as mock_opensearch: + test_db = OpenSearchDB( + host="localhost", + port=9200, + collection_name="test_collection", + embedding_model_dims=1536, + http_auth=mock_signer, + verify_certs=True, + use_ssl=True, + auto_create_index=False + ) + + # Verify OpenSearch was initialized with correct params + mock_opensearch.assert_called_once_with( + hosts=[{"host": "localhost", "port": 9200}], + http_auth=mock_signer, + use_ssl=True, + verify_certs=True, + connection_class=unittest.mock.ANY + ) \ No newline at end of file