Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/fix opensearch vector mapping #2399

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mem0/configs/vector_stores/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class OpenSearchConfig(BaseModel):
verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
auto_create_index: bool = Field(True, description="Automatically create index during initialization")
http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")

@model_validator(mode="before")
@classmethod
Expand All @@ -23,7 +24,7 @@ def validate_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
raise ValueError("Host must be provided for OpenSearch")

# Authentication: Either API key or user/password must be provided
if not any([values.get("api_key"), (values.get("user") and values.get("password"))]):
if not any([values.get("api_key"), (values.get("user") and values.get("password")), values.get("http_auth")]):
raise ValueError("Either api_key or user/password must be provided for OpenSearch authentication")

return values
Expand Down
12 changes: 8 additions & 4 deletions mem0/vector_stores/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, List, Optional

try:
from opensearchpy import OpenSearch
from opensearchpy import OpenSearch, RequestsHttpConnection
from opensearchpy.helpers import bulk
except ImportError:
raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None
Expand All @@ -28,9 +28,10 @@ def __init__(self, **kwargs):
# Initialize OpenSearch client
self.client = OpenSearch(
hosts=[{"host": config.host, "port": config.port or 9200}],
http_auth=(config.user, config.password) if (config.user and config.password) else None,
http_auth=config.http_auth if config.http_auth else ((config.user, config.password) if (config.user and config.password) else None),
use_ssl=config.use_ssl,
verify_certs=config.verify_certs,
connection_class=RequestsHttpConnection
)

self.collection_name = config.collection_name
Expand All @@ -43,14 +44,17 @@ def __init__(self, **kwargs):
def create_index(self) -> None:
"""Create OpenSearch index with proper mappings if it doesn't exist."""
index_settings = {
# ToDo change replicas to 1
"settings": {
"index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "1s", "knn": True}
},
"mappings": {
"properties": {
"text": {"type": "text"},
"vector": {"type": "knn_vector", "dimension": self.vector_dim},
"vector": {
"type": "knn_vector",
"dimension": self.vector_dim,
"method": {"engine": "lucene", "name": "hnsw", "space_type": "cosinesimil"},
},
"metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
}
},
Expand Down
28 changes: 27 additions & 1 deletion tests/vector_stores/test_opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import dotenv

try:
from opensearchpy import OpenSearch
from opensearchpy import OpenSearch, AWSV4SignerAuth
except ImportError:
raise ImportError(
"OpenSearch requires extra dependencies. Install with `pip install opensearch-py`"
Expand Down Expand Up @@ -148,3 +148,29 @@ def test_delete(self):
def test_delete_col(self):
self.os_db.delete_col()
self.client_mock.indices.delete.assert_called_once_with(index="test_collection")


def test_init_with_http_auth(self):
mock_credentials = MagicMock()
mock_signer = AWSV4SignerAuth(mock_credentials, "us-east-1", "es")

with patch('mem0.vector_stores.opensearch.OpenSearch') as mock_opensearch:
test_db = OpenSearchDB(
host="localhost",
port=9200,
collection_name="test_collection",
embedding_model_dims=1536,
http_auth=mock_signer,
verify_certs=True,
use_ssl=True,
auto_create_index=False
)

# Verify OpenSearch was initialized with correct params
mock_opensearch.assert_called_once_with(
hosts=[{"host": "localhost", "port": 9200}],
http_auth=mock_signer,
use_ssl=True,
verify_certs=True,
connection_class=unittest.mock.ANY
)
Loading