From a24d8d25b9d2005b4b07d300a99dfee2400098a1 Mon Sep 17 00:00:00 2001 From: Farzad Date: Mon, 3 Feb 2025 10:00:07 -0600 Subject: [PATCH 1/4] azure ai search enhancements --- mem0/vector_stores/azure_ai_search.py | 116 +++++++++---- poetry.lock | 62 ++++++- tests/vector_stores/test_azure_ai_search.py | 182 ++++++++++++++++++++ 3 files changed, 321 insertions(+), 39 deletions(-) create mode 100644 tests/vector_stores/test_azure_ai_search.py diff --git a/mem0/vector_stores/azure_ai_search.py b/mem0/vector_stores/azure_ai_search.py index c7d5cb2d4b..81017ca0ae 100644 --- a/mem0/vector_stores/azure_ai_search.py +++ b/mem0/vector_stores/azure_ai_search.py @@ -11,6 +11,9 @@ from azure.core.exceptions import ResourceNotFoundError from azure.search.documents import SearchClient from azure.search.documents.indexes import SearchIndexClient + from azure.search.documents.indexes.models import ( + BinaryQuantizationCompression, # Added for binary quantization + ) from azure.search.documents.indexes.models import ( HnswAlgorithmConfiguration, ScalarQuantizationCompression, @@ -24,7 +27,7 @@ from azure.search.documents.models import VectorizedQuery except ImportError: raise ImportError( - "The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.1'." + "The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.2'." ) logger = logging.getLogger(__name__) @@ -37,42 +40,73 @@ class OutputData(BaseModel): class AzureAISearch(VectorStoreBase): - def __init__(self, service_name, collection_name, api_key, embedding_model_dims, use_compression): - """Initialize the Azure Cognitive Search vector store. + def __init__( + self, + service_name, + collection_name, + api_key, + embedding_model_dims, + compression_type="none", # "none", "scalar", or "binary" + use_float16=False, + ): + """ + Initialize the Azure AI Search vector store. Args: - service_name (str): Azure Cognitive Search service name. + service_name (str): Azure AI Search service name. collection_name (str): Index name. - api_key (str): API key for the Azure Cognitive Search service. + api_key (str): API key for the Azure AI Search service. embedding_model_dims (int): Dimension of the embedding vector. - use_compression (bool): Use scalar quantization vector compression + compression_type (str): Specifies the type of quantization to use. + Allowed values are "none", "scalar", or "binary". + use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single). """ self.index_name = collection_name self.collection_name = collection_name self.embedding_model_dims = embedding_model_dims - self.use_compression = use_compression + self.compression_type = compression_type.lower() + self.use_float16 = use_float16 + self.search_client = SearchClient( endpoint=f"https://{service_name}.search.windows.net", index_name=self.index_name, credential=AzureKeyCredential(api_key), ) self.index_client = SearchIndexClient( - endpoint=f"https://{service_name}.search.windows.net", credential=AzureKeyCredential(api_key) + endpoint=f"https://{service_name}.search.windows.net", + credential=AzureKeyCredential(api_key), ) + + # Inject custom UserAgent header ("mem0") for tracking indexes created via mem0. 
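        # A less invasive alternative, if the installed SDK honors it: azure-core
        # pipeline clients generally accept a "user_agent" keyword that is prepended
        # to the User-Agent header, which would avoid the private _client._config
        # access below. A sketch only (the keyword pass-through is an assumption
        # for this azure-search-documents version):
        #
        #     self.search_client = SearchClient(
        #         endpoint=f"https://{service_name}.search.windows.net",
        #         index_name=self.index_name,
        #         credential=AzureKeyCredential(api_key),
        #         user_agent="mem0",
        #     )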
+ try: + self.search_client._client._config.user_agent_policy.add_user_agent("mem0") + self.index_client._client._config.user_agent_policy.add_user_agent("mem0") + except Exception as e: + logger.warning(f"Failed to add custom UserAgent header: {e}") + self.create_col() # create the collection / index def create_col(self): - """Create a new index in Azure Cognitive Search.""" - vector_dimensions = self.embedding_model_dims # Set this to the number of dimensions in your vector - - if self.use_compression: + """Create a new index in Azure AI Search.""" + # Determine vector type based on use_float16 setting. + if self.use_float16: vector_type = "Collection(Edm.Half)" - compression_name = "myCompression" - compression_configurations = [ScalarQuantizationCompression(compression_name=compression_name)] else: vector_type = "Collection(Edm.Single)" - compression_name = None - compression_configurations = [] + + # Configure compression settings based on the specified compression_type. + compression_configurations = [] + compression_name = None + if self.compression_type == "scalar": + compression_name = "myCompression" + compression_configurations = [ + ScalarQuantizationCompression(compression_name=compression_name) + ] + elif self.compression_type == "binary": + compression_name = "myCompression" + compression_configurations = [ + BinaryQuantizationCompression(compression_name=compression_name) + ] fields = [ SimpleField(name="id", type=SearchFieldDataType.String, key=True), @@ -82,8 +116,8 @@ def create_col(self): SearchField( name="vector", type=vector_type, - searchable=True, - vector_search_dimensions=vector_dimensions, + searchable=True, # Tehnically don't need this on SearchField but leaving for visibility + vector_search_dimensions=self.embedding_model_dims, vector_search_profile_name="my-vector-config", ), SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True), @@ -101,14 +135,16 @@ def create_col(self): def _generate_document(self, vector, payload, id): document = {"id": id, "vector": vector, "payload": json.dumps(payload)} - # Extract additional fields if they exist + # Extract additional fields if they exist. for field in ["user_id", "run_id", "agent_id"]: if field in payload: document[field] = payload[field] return document + # Note: Explicit "insert" calls may later be decoupled from memory management decisions. def insert(self, vectors, payloads=None, ids=None): - """Insert vectors into the index. + """ + Insert vectors into the index. Args: vectors (List[List[float]]): List of vectors to insert. @@ -116,7 +152,6 @@ def insert(self, vectors, payloads=None, ids=None): ids (List[str], optional): List of IDs corresponding to vectors. """ logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}") - documents = [ self._generate_document(vector, payload, id) for id, vector, payload in zip(ids, vectors, payloads) @@ -126,28 +161,28 @@ def insert(self, vectors, payloads=None, ids=None): def _build_filter_expression(self, filters): filter_conditions = [] for key, value in filters.items(): - # If the value is a string, add quotes if isinstance(value, str): condition = f"{key} eq '{value}'" else: condition = f"{key} eq {value}" filter_conditions.append(condition) - # Use 'and' to join multiple conditions - filter_expression = ' and '.join(filter_conditions) + filter_expression = " and ".join(filter_conditions) return filter_expression - def search(self, query, limit=5, filters=None): - """Search for similar vectors. 
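    # For reference, _build_filter_expression above joins the conditions into an
    # OData filter string; with illustrative inputs:
    #
    #     {"user_id": "alice", "agent_id": "agent-7"}  ->  "user_id eq 'alice' and agent_id eq 'agent-7'"
    #     {"run_id": 42}                               ->  "run_id eq 42"   (non-strings are unquoted)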
+ def search(self, query, limit=5, filters=None, vector_filter_mode="preFilter"): + """ + Search for similar vectors. Args: - query (List[float]): Query vectors. + query (List[float]): Query vector. limit (int, optional): Number of results to return. Defaults to 5. filters (Dict, optional): Filters to apply to the search. Defaults to None. + vector_filter_mode (str): Determines whether filters are applied before or after the vector search. + Known values: "preFilter" (default) and "postFilter". Returns: list: Search results. """ - # Build filter expression filter_expression = None if filters: filter_expression = self._build_filter_expression(filters) @@ -155,10 +190,12 @@ def search(self, query, limit=5, filters=None): vector_query = VectorizedQuery( vector=query, k_nearest_neighbors=limit, fields="vector" ) + # Pass vector_filter_mode to the search call. search_results = self.search_client.search( vector_queries=[vector_query], filter=filter_expression, - top=limit + top=limit, + vector_filter_mode=vector_filter_mode # New query parameter for filter mode. ) results = [] @@ -168,7 +205,8 @@ def search(self, query, limit=5, filters=None): return results def delete(self, vector_id): - """Delete a vector by ID. + """ + Delete a vector by ID. Args: vector_id (str): ID of the vector to delete. @@ -177,7 +215,8 @@ def delete(self, vector_id): logger.info(f"Deleted document with ID '{vector_id}' from index '{self.index_name}'.") def update(self, vector_id, vector=None, payload=None): - """Update a vector and its payload. + """ + Update a vector and its payload. Args: vector_id (str): ID of the vector to update. @@ -195,7 +234,8 @@ def update(self, vector_id, vector=None, payload=None): self.search_client.merge_or_upload_documents(documents=[document]) def get(self, vector_id) -> OutputData: - """Retrieve a vector by ID. + """ + Retrieve a vector by ID. Args: vector_id (str): ID of the vector to retrieve. @@ -210,7 +250,8 @@ def get(self, vector_id) -> OutputData: return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"])) def list_cols(self) -> List[str]: - """List all collections (indexes). + """ + List all collections (indexes). Returns: List[str]: List of index names. @@ -223,7 +264,8 @@ def delete_col(self): self.index_client.delete_index(self.index_name) def col_info(self): - """Get information about the index. + """ + Get information about the index. Returns: Dict[str, Any]: Index information. @@ -232,7 +274,8 @@ def col_info(self): return {"name": index.name, "fields": index.fields} def list(self, filters=None, limit=100): - """List all vectors in the index. + """ + List all vectors in the index. Args: filters (Dict, optional): Filters to apply to the list. @@ -254,8 +297,7 @@ def list(self, filters=None, limit=100): for result in search_results: payload = json.loads(result["payload"]) results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload)) - - return [results] + return results def __del__(self): """Close the search client when the object is deleted.""" diff --git a/poetry.lock b/poetry.lock index 08c2602c83..26e2dff817 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. 
[[package]] name = "aiohappyeyeballs" @@ -186,6 +186,53 @@ docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphi tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] +[[package]] +name = "azure-common" +version = "1.1.28" +description = "Microsoft Azure Client Library for Python (Common)" +optional = false +python-versions = "*" +files = [ + {file = "azure-common-1.1.28.zip", hash = "sha256:4ac0cd3214e36b6a1b6a442686722a5d8cc449603aa833f3f0f40bda836704a3"}, + {file = "azure_common-1.1.28-py2.py3-none-any.whl", hash = "sha256:5c12d3dcf4ec20599ca6b0d3e09e86e146353d443e7fcc050c9a19c1f9df20ad"}, +] + +[[package]] +name = "azure-core" +version = "1.32.0" +description = "Microsoft Azure Core Library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "azure_core-1.32.0-py3-none-any.whl", hash = "sha256:eac191a0efb23bfa83fddf321b27b122b4ec847befa3091fa736a5c32c50d7b4"}, + {file = "azure_core-1.32.0.tar.gz", hash = "sha256:22b3c35d6b2dae14990f6c1be2912bf23ffe50b220e708a28ab1bb92b1c730e5"}, +] + +[package.dependencies] +requests = ">=2.21.0" +six = ">=1.11.0" +typing-extensions = ">=4.6.0" + +[package.extras] +aio = ["aiohttp (>=3.0)"] + +[[package]] +name = "azure-search-documents" +version = "11.5.2" +description = "Microsoft Azure Cognitive Search Client Library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "azure_search_documents-11.5.2-py3-none-any.whl", hash = "sha256:c949d011008a4b0bcee3db91132741b4e4d50ddb3f7e2f48944d949d4b413b11"}, + {file = "azure_search_documents-11.5.2.tar.gz", hash = "sha256:98977dd1fa4978d3b7d8891a0856b3becb6f02cc07ff2e1ea40b9c7254ada315"}, +] + +[package.dependencies] +azure-common = ">=1.1" +azure-core = ">=1.28.0" +isodate = ">=0.6.0" +typing-extensions = ">=4.6.0" + [[package]] name = "backoff" version = "2.2.1" @@ -813,6 +860,17 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "isodate" +version = "0.7.2" +description = "An ISO 8601 date/time/duration parser and formatter" +optional = false +python-versions = ">=3.7" +files = [ + {file = "isodate-0.7.2-py3-none-any.whl", hash = "sha256:28009937d8031054830160fce6d409ed342816b543597cece116d966c6d99e15"}, + {file = "isodate-0.7.2.tar.gz", hash = "sha256:4cd1aa0f43ca76f4a6c6c0292a85f40b35ec2e43e315b59f06e6d32171a953e6"}, +] + [[package]] name = "isort" version = "5.13.2" @@ -2397,4 +2455,4 @@ graph = ["langchain-community", "neo4j", "rank-bm25"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "6dba8e3091ed9a081a2035313e674206e2bee10072cb5c0ce6fb407686ce5156" +content-hash = "91b6996f9590235d0a4c6504cb4b9aafe68a3161e204a0fd9ea11b2a0c467cc1" diff --git a/tests/vector_stores/test_azure_ai_search.py b/tests/vector_stores/test_azure_ai_search.py new file mode 100644 index 0000000000..a7e23662ac --- /dev/null +++ b/tests/vector_stores/test_azure_ai_search.py @@ -0,0 +1,182 @@ +import json +from unittest.mock import Mock, patch + +import pytest + +# Import the AzureAISearch class and OutputData model from your module. +from mem0.vector_stores.azure_ai_search import AzureAISearch + + +# Fixture to patch SearchClient and SearchIndexClient and create an instance of AzureAISearch. 
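# Note: patch() must target the names where they are *used* — the module under
# test imports SearchClient into its own namespace, so patching
# "mem0.vector_stores.azure_ai_search.SearchClient" (rather than
# "azure.search.documents.SearchClient") is what makes AzureAISearch.__init__
# construct the mocked clients set up below.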
+@pytest.fixture +def mock_clients(): + with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \ + patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient: + # Create mocked instances for search and index clients. + mock_search_client = MockSearchClient.return_value + mock_index_client = MockIndexClient.return_value + + # Stub required methods on search_client. + mock_search_client.upload_documents = Mock() + mock_search_client.search = Mock() + mock_search_client.delete_documents = Mock() + mock_search_client.merge_or_upload_documents = Mock() + mock_search_client.get_document = Mock() + mock_search_client.close = Mock() + + # Stub required methods on index_client. + mock_index_client.create_or_update_index = Mock() + mock_index_client.list_indexes = Mock(return_value=[]) + mock_index_client.delete_index = Mock() + # For col_info() we assume get_index returns an object with name and fields attributes. + fake_index = Mock() + fake_index.name = "test-index" + fake_index.fields = ["id", "vector", "payload"] + mock_index_client.get_index = Mock(return_value=fake_index) + mock_index_client.close = Mock() + + yield mock_search_client, mock_index_client + +@pytest.fixture +def azure_ai_search_instance(mock_clients): + mock_search_client, mock_index_client = mock_clients + # Create an instance with dummy parameters. + instance = AzureAISearch( + service_name="test-service", + collection_name="test-index", + api_key="test-api-key", + embedding_model_dims=3, + compression_type="binary", # testing binary quantization option + use_float16=True + ) + # Return instance and clients for verification. + return instance, mock_search_client, mock_index_client + +def test_create_col(azure_ai_search_instance): + instance, mock_search_client, mock_index_client = azure_ai_search_instance + # Upon initialization, create_col should be called. + mock_index_client.create_or_update_index.assert_called_once() + # Optionally, you could inspect the call arguments for vector type. + +def test_insert(azure_ai_search_instance): + instance, mock_search_client, _ = azure_ai_search_instance + vectors = [[0.1, 0.2, 0.3]] + payloads = [{"user_id": "user1", "run_id": "run1"}] + ids = ["doc1"] + + instance.insert(vectors, payloads, ids) + + mock_search_client.upload_documents.assert_called_once() + args, _ = mock_search_client.upload_documents.call_args + documents = args[0] + # Update expected_doc to include extra fields from payload. + expected_doc = { + "id": "doc1", + "vector": [0.1, 0.2, 0.3], + "payload": json.dumps({"user_id": "user1", "run_id": "run1"}), + "user_id": "user1", + "run_id": "run1" + } + assert documents[0] == expected_doc + +def test_search_preFilter(azure_ai_search_instance): + instance, mock_search_client, _ = azure_ai_search_instance + # Setup a fake search result returned by the mocked search method. + fake_result = { + "id": "doc1", + "@search.score": 0.95, + "payload": json.dumps({"user_id": "user1"}) + } + # Configure the mock to return an iterator (e.g., a list) with fake_result. + mock_search_client.search.return_value = [fake_result] + + query_vector = [0.1, 0.2, 0.3] + results = instance.search(query_vector, limit=1, filters={"user_id": "user1"}, vector_filter_mode="preFilter") + + # Verify that the search method was called with vector_filter_mode="preFilter". 
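    # (The tuple-unpacking of call_args below is equivalent to
    # mock_search_client.search.call_args.kwargs, available since Python 3.8;
    # either form verifies the keyword arguments of the recorded call.)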
+ mock_search_client.search.assert_called_once() + _, called_kwargs = mock_search_client.search.call_args + assert called_kwargs.get("vector_filter_mode") == "preFilter" + + # Verify that the output is parsed correctly. + assert len(results) == 1 + assert results[0].id == "doc1" + assert results[0].score == 0.95 + assert results[0].payload == {"user_id": "user1"} + +def test_search_postFilter(azure_ai_search_instance): + instance, mock_search_client, _ = azure_ai_search_instance + # Setup a fake search result for postFilter. + fake_result = { + "id": "doc2", + "@search.score": 0.85, + "payload": json.dumps({"user_id": "user2"}) + } + mock_search_client.search.return_value = [fake_result] + + query_vector = [0.4, 0.5, 0.6] + results = instance.search(query_vector, limit=1, filters={"user_id": "user2"}, vector_filter_mode="postFilter") + + mock_search_client.search.assert_called_once() + _, called_kwargs = mock_search_client.search.call_args + assert called_kwargs.get("vector_filter_mode") == "postFilter" + + assert len(results) == 1 + assert results[0].id == "doc2" + assert results[0].score == 0.85 + assert results[0].payload == {"user_id": "user2"} + +def test_delete(azure_ai_search_instance): + instance, mock_search_client, _ = azure_ai_search_instance + vector_id = "doc1" + instance.delete(vector_id) + mock_search_client.delete_documents.assert_called_once_with(documents=[{"id": vector_id}]) + +def test_update(azure_ai_search_instance): + instance, mock_search_client, _ = azure_ai_search_instance + vector_id = "doc1" + new_vector = [0.7, 0.8, 0.9] + new_payload = {"user_id": "updated"} + instance.update(vector_id, vector=new_vector, payload=new_payload) + mock_search_client.merge_or_upload_documents.assert_called_once() + kwargs = mock_search_client.merge_or_upload_documents.call_args.kwargs + document = kwargs["documents"][0] + assert document["id"] == vector_id + assert document["vector"] == new_vector + assert document["payload"] == json.dumps(new_payload) + # The update method will also add the 'user_id' field. + assert document["user_id"] == "updated" + +def test_get(azure_ai_search_instance): + instance, mock_search_client, _ = azure_ai_search_instance + fake_result = { + "id": "doc1", + "payload": json.dumps({"user_id": "user1"}) + } + mock_search_client.get_document.return_value = fake_result + result = instance.get("doc1") + mock_search_client.get_document.assert_called_once_with(key="doc1") + assert result.id == "doc1" + assert result.payload == {"user_id": "user1"} + assert result.score is None + +def test_list(azure_ai_search_instance): + instance, mock_search_client, _ = azure_ai_search_instance + fake_result = { + "id": "doc1", + "@search.score": 0.99, + "payload": json.dumps({"user_id": "user1"}) + } + mock_search_client.search.return_value = [fake_result] + # Call list with a simple filter. + results = instance.list(filters={"user_id": "user1"}, limit=1) + # Verify the search method was called with the proper parameters. 
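    # With filters={"user_id": "user1"}, the helper should produce the OData
    # string "user_id eq 'user1'", so the exact-match assertion below pins the
    # full parameter set of the search() call.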
+ expected_filter = instance._build_filter_expression({"user_id": "user1"}) + mock_search_client.search.assert_called_once_with( + search_text="*", + filter=expected_filter, + top=1 + ) + assert isinstance(results, list) + assert len(results) == 1 + assert results[0].id == "doc1" From 8abc19d8bcc245fc4f73af475a71ed2efc1c14c4 Mon Sep 17 00:00:00 2001 From: Farzad Date: Fri, 7 Feb 2025 08:19:23 -0600 Subject: [PATCH 2/4] azs pr comments --- mem0/vector_stores/azure_ai_search.py | 117 +++++++++++++------- tests/vector_stores/test_azure_ai_search.py | 11 +- 2 files changed, 90 insertions(+), 38 deletions(-) diff --git a/mem0/vector_stores/azure_ai_search.py b/mem0/vector_stores/azure_ai_search.py index 81017ca0ae..d612264381 100644 --- a/mem0/vector_stores/azure_ai_search.py +++ b/mem0/vector_stores/azure_ai_search.py @@ -1,5 +1,6 @@ import json import logging +import re from typing import List, Optional from pydantic import BaseModel @@ -12,9 +13,7 @@ from azure.search.documents import SearchClient from azure.search.documents.indexes import SearchIndexClient from azure.search.documents.indexes.models import ( - BinaryQuantizationCompression, # Added for binary quantization - ) - from azure.search.documents.indexes.models import ( + BinaryQuantizationCompression, HnswAlgorithmConfiguration, ScalarQuantizationCompression, SearchField, @@ -46,8 +45,8 @@ def __init__( collection_name, api_key, embedding_model_dims, - compression_type="none", # "none", "scalar", or "binary" - use_float16=False, + compression_type: Optional[str] = None, + use_float16: bool = False, ): """ Initialize the Azure AI Search vector store. @@ -57,14 +56,16 @@ def __init__( collection_name (str): Index name. api_key (str): API key for the Azure AI Search service. embedding_model_dims (int): Dimension of the embedding vector. - compression_type (str): Specifies the type of quantization to use. - Allowed values are "none", "scalar", or "binary". + compression_type (Optional[str]): Specifies the type of quantization to use. + Allowed values are None (no quantization), "scalar", or "binary". use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single). + (Note: This flag is preserved from the initial implementation per feedback.) """ self.index_name = collection_name self.collection_name = collection_name self.embedding_model_dims = embedding_model_dims - self.compression_type = compression_type.lower() + # If compression_type is None, treat it as "none". + self.compression_type = (compression_type or "none").lower() self.use_float16 = use_float16 self.search_client = SearchClient( @@ -77,12 +78,8 @@ def __init__( credential=AzureKeyCredential(api_key), ) - # Inject custom UserAgent header ("mem0") for tracking indexes created via mem0. - try: - self.search_client._client._config.user_agent_policy.add_user_agent("mem0") - self.index_client._client._config.user_agent_policy.add_user_agent("mem0") - except Exception as e: - logger.warning(f"Failed to add custom UserAgent header: {e}") + self.search_client._client._config.user_agent_policy.add_user_agent("mem0") + self.index_client._client._config.user_agent_policy.add_user_agent("mem0") self.create_col() # create the collection / index @@ -99,15 +96,28 @@ def create_col(self): compression_name = None if self.compression_type == "scalar": compression_name = "myCompression" + # For SQ, rescoring defaults to True and oversampling defaults to 4. 
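            # Should those defaults ever need tuning, the compression models take
            # explicit overrides — a sketch only; the keyword names mirror the
            # REST fields rerankWithOriginalVectors/defaultOversampling and should
            # be verified against the installed azure-search-documents version:
            #
            #     ScalarQuantizationCompression(
            #         compression_name="myCompression",
            #         rerank_with_original_vectors=True,
            #         default_oversampling=4.0,
            #     )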
compression_configurations = [ - ScalarQuantizationCompression(compression_name=compression_name) + ScalarQuantizationCompression( + compression_name=compression_name + # rescoring defaults to True and oversampling defaults to 4 + ) ] elif self.compression_type == "binary": compression_name = "myCompression" + # For BQ, rescoring defaults to True and oversampling defaults to 10. compression_configurations = [ - BinaryQuantizationCompression(compression_name=compression_name) + BinaryQuantizationCompression( + compression_name=compression_name + # rescoring defaults to True and oversampling defaults to 10 + ) ] - + # If no compression is desired, compression_configurations remains empty. + + # Note regarding hybrid search: + # FEEDBACK (Discussion): We could store an additional "text" field for hybrid search (keeping the original text searchable) + # but that would change the index design. This is not implemented here to avoid breaking changes. + fields = [ SimpleField(name="id", type=SearchFieldDataType.String, key=True), SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True), @@ -116,7 +126,7 @@ def create_col(self): SearchField( name="vector", type=vector_type, - searchable=True, # Tehnically don't need this on SearchField but leaving for visibility + searchable=True, vector_search_dimensions=self.embedding_model_dims, vector_search_profile_name="my-vector-config", ), @@ -125,7 +135,11 @@ def create_col(self): vector_search = VectorSearch( profiles=[ - VectorSearchProfile(name="my-vector-config", algorithm_configuration_name="my-algorithms-config") + VectorSearchProfile( + name="my-vector-config", + algorithm_configuration_name="my-algorithms-config", + compression_name=compression_name if self.compression_type != "none" else None + ) ], algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")], compressions=compression_configurations, @@ -156,15 +170,24 @@ def insert(self, vectors, payloads=None, ids=None): self._generate_document(vector, payload, id) for id, vector, payload in zip(ids, vectors, payloads) ] - self.search_client.upload_documents(documents) + response = self.search_client.upload_documents(documents) + for doc in response: + if not doc.get("status", False): + raise Exception(f"Insert failed for document {doc.get('id')}: {doc}") + return response + + def _sanitize_key(self, key: str) -> str: + return re.sub(r"[^\w]", "", key) def _build_filter_expression(self, filters): filter_conditions = [] for key, value in filters.items(): + safe_key = self._sanitize_key(key) if isinstance(value, str): - condition = f"{key} eq '{value}'" + safe_value = value.replace("'", "''") + condition = f"{safe_key} eq '{safe_value}'" else: - condition = f"{key} eq {value}" + condition = f"{safe_key} eq {value}" filter_conditions.append(condition) filter_expression = " and ".join(filter_conditions) return filter_expression @@ -181,7 +204,7 @@ def search(self, query, limit=5, filters=None, vector_filter_mode="preFilter"): Known values: "preFilter" (default) and "postFilter". Returns: - list: Search results. + List[OutputData]: Search results. """ filter_expression = None if filters: @@ -190,18 +213,22 @@ def search(self, query, limit=5, filters=None, vector_filter_mode="preFilter"): vector_query = VectorizedQuery( vector=query, k_nearest_neighbors=limit, fields="vector" ) - # Pass vector_filter_mode to the search call. 
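        # Callers can pass the plain strings "preFilter"/"postFilter" named in the
        # docstring; the SDK also exposes an enum for the same values (assumed
        # available in this azure-search-documents version):
        #
        #     from azure.search.documents.models import VectorFilterMode
        #     store.search(query, vector_filter_mode=VectorFilterMode.POST_FILTER)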
search_results = self.search_client.search( vector_queries=[vector_query], filter=filter_expression, top=limit, - vector_filter_mode=vector_filter_mode # New query parameter for filter mode. + vector_filter_mode=vector_filter_mode, # FEEDBACK 3: New parameter for filter mode. ) results = [] for result in search_results: payload = json.loads(result["payload"]) - results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload)) + results.append( + OutputData( + id=result["id"], score=result["@search.score"], payload=payload + ) + ) + # FEEDBACK 5/6: Return the list of results directly. return results def delete(self, vector_id): @@ -211,8 +238,13 @@ def delete(self, vector_id): Args: vector_id (str): ID of the vector to delete. """ - self.search_client.delete_documents(documents=[{"id": vector_id}]) + response = self.search_client.delete_documents(documents=[{"id": vector_id}]) + # FEEDBACK 6: Check delete response; throw an exception if any deletion failed. + for doc in response: + if not doc.get("status", False): + raise Exception(f"Delete failed for document {vector_id}: {doc}") logger.info(f"Deleted document with ID '{vector_id}' from index '{self.index_name}'.") + return response def update(self, vector_id, vector=None, payload=None): """ @@ -231,7 +263,11 @@ def update(self, vector_id, vector=None, payload=None): document["payload"] = json_payload for field in ["user_id", "run_id", "agent_id"]: document[field] = payload.get(field) - self.search_client.merge_or_upload_documents(documents=[document]) + response = self.search_client.merge_or_upload_documents(documents=[document]) + for doc in response: + if not doc.get("status", False): + raise Exception(f"Update failed for document {vector_id}: {doc}") + return response def get(self, vector_id) -> OutputData: """ @@ -247,7 +283,9 @@ def get(self, vector_id) -> OutputData: result = self.search_client.get_document(key=vector_id) except ResourceNotFoundError: return None - return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"])) + return OutputData( + id=result["id"], score=None, payload=json.loads(result["payload"]) + ) def list_cols(self) -> List[str]: """ @@ -256,8 +294,11 @@ def list_cols(self) -> List[str]: Returns: List[str]: List of index names. """ - indexes = self.index_client.list_indexes() - return [index.name for index in indexes] + try: + names = self.index_client.list_index_names() + except AttributeError: + names = [index.name for index in self.index_client.list_indexes()] + return names def delete_col(self): """Delete the index.""" @@ -268,7 +309,7 @@ def col_info(self): Get information about the index. Returns: - Dict[str, Any]: Index information. + dict: Index information. """ index = self.index_client.get_index(self.index_name) return {"name": index.name, "fields": index.fields} @@ -278,7 +319,7 @@ def list(self, filters=None, limit=100): List all vectors in the index. Args: - filters (Dict, optional): Filters to apply to the list. + filters (dict, optional): Filters to apply to the list. limit (int, optional): Number of vectors to return. Defaults to 100. 
Returns: @@ -289,14 +330,16 @@ def list(self, filters=None, limit=100): filter_expression = self._build_filter_expression(filters) search_results = self.search_client.search( - search_text="*", - filter=filter_expression, - top=limit + search_text="*", filter=filter_expression, top=limit ) results = [] for result in search_results: payload = json.loads(result["payload"]) - results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload)) + results.append( + OutputData( + id=result["id"], score=result["@search.score"], payload=payload + ) + ) return results def __del__(self): diff --git a/tests/vector_stores/test_azure_ai_search.py b/tests/vector_stores/test_azure_ai_search.py index a7e23662ac..4dfa445ab9 100644 --- a/tests/vector_stores/test_azure_ai_search.py +++ b/tests/vector_stores/test_azure_ai_search.py @@ -17,10 +17,14 @@ def mock_clients(): mock_index_client = MockIndexClient.return_value # Stub required methods on search_client. + # FEEDBACK: Set default return values for methods that are iterated over. mock_search_client.upload_documents = Mock() + mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1"}] mock_search_client.search = Mock() mock_search_client.delete_documents = Mock() + mock_search_client.delete_documents.return_value = [{"status": True, "id": "doc1"}] mock_search_client.merge_or_upload_documents = Mock() + mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": "doc1"}] mock_search_client.get_document = Mock() mock_search_client.close = Mock() @@ -64,6 +68,7 @@ def test_insert(azure_ai_search_instance): payloads = [{"user_id": "user1", "run_id": "run1"}] ids = ["doc1"] + # FEEDBACK: The upload_documents method is now set to return an iterable. instance.insert(vectors, payloads, ids) mock_search_client.upload_documents.assert_called_once() @@ -87,7 +92,7 @@ def test_search_preFilter(azure_ai_search_instance): "@search.score": 0.95, "payload": json.dumps({"user_id": "user1"}) } - # Configure the mock to return an iterator (e.g., a list) with fake_result. + # Configure the mock to return an iterator (list) with fake_result. mock_search_client.search.return_value = [fake_result] query_vector = [0.1, 0.2, 0.3] @@ -129,6 +134,8 @@ def test_search_postFilter(azure_ai_search_instance): def test_delete(azure_ai_search_instance): instance, mock_search_client, _ = azure_ai_search_instance vector_id = "doc1" + # Set delete_documents to return an iterable with a successful response. + mock_search_client.delete_documents.return_value = [{"status": True, "id": vector_id}] instance.delete(vector_id) mock_search_client.delete_documents.assert_called_once_with(documents=[{"id": vector_id}]) @@ -137,6 +144,8 @@ def test_update(azure_ai_search_instance): vector_id = "doc1" new_vector = [0.7, 0.8, 0.9] new_payload = {"user_id": "updated"} + # Set merge_or_upload_documents to return an iterable with a successful response. 
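    # (The per-document status dict below matches the implementation's
    # doc.get("status") check. Note that the real SDK returns IndexingResult
    # objects, which expose a `succeeded` attribute rather than a "status" key —
    # worth remembering if the result-checking logic is ever exercised against a
    # live service.)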
+ mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": vector_id}] instance.update(vector_id, vector=new_vector, payload=new_payload) mock_search_client.merge_or_upload_documents.assert_called_once() kwargs = mock_search_client.merge_or_upload_documents.call_args.kwargs From 579a1285575b13cc2ebeda09dd1cb26dcc0e6d04 Mon Sep 17 00:00:00 2001 From: Farzad Date: Fri, 14 Mar 2025 07:56:33 -0500 Subject: [PATCH 3/4] updated config azs --- mem0/configs/vector_stores/azure_ai_search.py | 44 +- tests/vector_stores/test_azure_ai_search.py | 778 +++++++++--------- 2 files changed, 441 insertions(+), 381 deletions(-) diff --git a/mem0/configs/vector_stores/azure_ai_search.py b/mem0/configs/vector_stores/azure_ai_search.py index 5619b3008e..b256e13985 100644 --- a/mem0/configs/vector_stores/azure_ai_search.py +++ b/mem0/configs/vector_stores/azure_ai_search.py @@ -1,27 +1,53 @@ -from typing import Any, Dict - +from typing import Any, Dict, Optional from pydantic import BaseModel, Field, model_validator class AzureAISearchConfig(BaseModel): collection_name: str = Field("mem0", description="Name of the collection") - service_name: str = Field(None, description="Azure Cognitive Search service name") - api_key: str = Field(None, description="API key for the Azure Cognitive Search service") + service_name: str = Field(None, description="Azure AI Search service name") + api_key: str = Field(None, description="API key for the Azure AI Search service") embedding_model_dims: int = Field(None, description="Dimension of the embedding vector") - use_compression: bool = Field(False, description="Whether to use scalar quantization vector compression.") - + compression_type: Optional[str] = Field( + None, + description="Type of vector compression to use. Options: 'scalar', 'binary', or None" + ) + use_float16: bool = Field( + False, + description="Whether to store vectors in half precision (Edm.Half) instead of full precision (Edm.Single)" + ) + @model_validator(mode="before") @classmethod def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]: allowed_fields = set(cls.model_fields.keys()) input_fields = set(values.keys()) extra_fields = input_fields - allowed_fields + + # Check for use_compression to provide a helpful error + if "use_compression" in extra_fields: + raise ValueError( + "The parameter 'use_compression' is no longer supported. " + "Please use 'compression_type=\"scalar\"' instead of 'use_compression=True' " + "or 'compression_type=None' instead of 'use_compression=False'." + ) + if extra_fields: raise ValueError( - f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}" + f"Extra fields not allowed: {', '.join(extra_fields)}. " + f"Please input only the following fields: {', '.join(allowed_fields)}" ) + + # Validate compression_type values + if "compression_type" in values and values["compression_type"] is not None: + valid_types = ["scalar", "binary"] + if values["compression_type"].lower() not in valid_types: + raise ValueError( + f"Invalid compression_type: {values['compression_type']}. 
" + f"Must be one of: {', '.join(valid_types)}, or None" + ) + return values - + model_config = { "arbitrary_types_allowed": True, - } + } \ No newline at end of file diff --git a/tests/vector_stores/test_azure_ai_search.py b/tests/vector_stores/test_azure_ai_search.py index 8d17756eb0..ac1106c530 100644 --- a/tests/vector_stores/test_azure_ai_search.py +++ b/tests/vector_stores/test_azure_ai_search.py @@ -1,20 +1,28 @@ import json -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock, call import pytest from azure.core.exceptions import ResourceNotFoundError, HttpResponseError -# Import the AzureAISearch class and OutputData model from your module. -from mem0.vector_stores.azure_ai_search import AzureAISearch +# Import the AzureAISearch class and related models +from mem0.vector_stores.azure_ai_search import AzureAISearch, OutputData +from mem0.configs.vector_stores.azure_ai_search import AzureAISearchConfig # Fixture to patch SearchClient and SearchIndexClient and create an instance of AzureAISearch. @pytest.fixture def mock_clients(): with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \ - patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient: + patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient, \ + patch("mem0.vector_stores.azure_ai_search.AzureKeyCredential") as MockAzureKeyCredential: # Create mocked instances for search and index clients. mock_search_client = MockSearchClient.return_value mock_index_client = MockIndexClient.return_value + + # Mock the client._client._config.user_agent_policy.add_user_agent + mock_search_client._client = MagicMock() + mock_search_client._client._config.user_agent_policy.add_user_agent = Mock() + mock_index_client._client = MagicMock() + mock_index_client._client._config.user_agent_policy.add_user_agent = Mock() # Stub required methods on search_client. mock_search_client.upload_documents = Mock() @@ -29,7 +37,7 @@ def mock_clients(): # Stub required methods on index_client. mock_index_client.create_or_update_index = Mock() - mock_index_client.list_indexes = Mock(return_value=[]) + mock_index_client.list_indexes = Mock() mock_index_client.list_index_names = Mock(return_value=["test-index"]) mock_index_client.delete_index = Mock() # For col_info() we assume get_index returns an object with name and fields attributes. @@ -39,11 +47,12 @@ def mock_clients(): mock_index_client.get_index = Mock(return_value=fake_index) mock_index_client.close = Mock() - yield mock_search_client, mock_index_client + yield mock_search_client, mock_index_client, MockAzureKeyCredential + @pytest.fixture def azure_ai_search_instance(mock_clients): - mock_search_client, mock_index_client = mock_clients + mock_search_client, mock_index_client, _ = mock_clients # Create an instance with dummy parameters. instance = AzureAISearch( service_name="test-service", @@ -56,150 +65,334 @@ def azure_ai_search_instance(mock_clients): # Return instance and clients for verification. return instance, mock_search_client, mock_index_client -# --- Original tests --- -def test_create_col(azure_ai_search_instance): - instance, mock_search_client, mock_index_client = azure_ai_search_instance - # Upon initialization, create_col should be called. 
+# --- Tests for AzureAISearchConfig --- + +def test_config_validation_valid(): + """Test valid configurations are accepted.""" + # Test minimal configuration + config = AzureAISearchConfig( + service_name="test-service", + api_key="test-api-key", + embedding_model_dims=768 + ) + assert config.collection_name == "mem0" # Default value + assert config.service_name == "test-service" + assert config.api_key == "test-api-key" + assert config.embedding_model_dims == 768 + assert config.compression_type is None + assert config.use_float16 is False + + # Test with all optional parameters + config = AzureAISearchConfig( + collection_name="custom-index", + service_name="test-service", + api_key="test-api-key", + embedding_model_dims=1536, + compression_type="scalar", + use_float16=True + ) + assert config.collection_name == "custom-index" + assert config.compression_type == "scalar" + assert config.use_float16 is True + + +def test_config_validation_invalid_compression_type(): + """Test that invalid compression types are rejected.""" + with pytest.raises(ValueError) as exc_info: + AzureAISearchConfig( + service_name="test-service", + api_key="test-api-key", + embedding_model_dims=768, + compression_type="invalid-type" # Not a valid option + ) + assert "Invalid compression_type" in str(exc_info.value) + + +def test_config_validation_deprecated_use_compression(): + """Test that using the deprecated use_compression parameter raises an error.""" + with pytest.raises(ValueError) as exc_info: + AzureAISearchConfig( + service_name="test-service", + api_key="test-api-key", + embedding_model_dims=768, + use_compression=True # Deprecated parameter + ) + # Fix: Use a partial string match instead of exact match + assert "use_compression" in str(exc_info.value) + assert "no longer supported" in str(exc_info.value) + + +def test_config_validation_extra_fields(): + """Test that extra fields are rejected.""" + with pytest.raises(ValueError) as exc_info: + AzureAISearchConfig( + service_name="test-service", + api_key="test-api-key", + embedding_model_dims=768, + unknown_parameter="value" # Extra field + ) + assert "Extra fields not allowed" in str(exc_info.value) + assert "unknown_parameter" in str(exc_info.value) + + +# --- Tests for AzureAISearch initialization --- + +def test_initialization(mock_clients): + """Test AzureAISearch initialization with different parameters.""" + mock_search_client, mock_index_client, mock_azure_key_credential = mock_clients + + # Test with minimal parameters + instance = AzureAISearch( + service_name="test-service", + collection_name="test-index", + api_key="test-api-key", + embedding_model_dims=768 + ) + + # Verify initialization parameters + assert instance.index_name == "test-index" + assert instance.collection_name == "test-index" + assert instance.embedding_model_dims == 768 + assert instance.compression_type == "none" # Default when None is passed + assert instance.use_float16 is False + + # Verify client creation + mock_azure_key_credential.assert_called_with("test-api-key") + assert "mem0" in mock_search_client._client._config.user_agent_policy.add_user_agent.call_args[0] + assert "mem0" in mock_index_client._client._config.user_agent_policy.add_user_agent.call_args[0] + + # Verify index creation was called mock_index_client.create_or_update_index.assert_called_once() - # Optionally, you could inspect the call arguments for vector type. 
-def test_insert(azure_ai_search_instance): - instance, mock_search_client, _ = azure_ai_search_instance - vectors = [[0.1, 0.2, 0.3]] - payloads = [{"user_id": "user1", "run_id": "run1"}] - ids = ["doc1"] - instance.insert(vectors, payloads, ids) +def test_initialization_with_compression_types(mock_clients): + """Test initialization with different compression types.""" + mock_search_client, mock_index_client, _ = mock_clients + + # Test with scalar compression + instance = AzureAISearch( + service_name="test-service", + collection_name="scalar-index", + api_key="test-api-key", + embedding_model_dims=768, + compression_type="scalar" + ) + assert instance.compression_type == "scalar" + + # Capture the index creation call + args, _ = mock_index_client.create_or_update_index.call_args_list[-1] + index = args[0] + # Verify scalar compression was configured + assert hasattr(index.vector_search, 'compressions') + assert len(index.vector_search.compressions) > 0 + assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0])) + + # Test with binary compression + instance = AzureAISearch( + service_name="test-service", + collection_name="binary-index", + api_key="test-api-key", + embedding_model_dims=768, + compression_type="binary" + ) + assert instance.compression_type == "binary" + + # Capture the index creation call + args, _ = mock_index_client.create_or_update_index.call_args_list[-1] + index = args[0] + # Verify binary compression was configured + assert hasattr(index.vector_search, 'compressions') + assert len(index.vector_search.compressions) > 0 + assert "BinaryQuantizationCompression" in str(type(index.vector_search.compressions[0])) + + # Test with no compression + instance = AzureAISearch( + service_name="test-service", + collection_name="no-compression-index", + api_key="test-api-key", + embedding_model_dims=768, + compression_type=None + ) + assert instance.compression_type == "none" + + # Capture the index creation call + args, _ = mock_index_client.create_or_update_index.call_args_list[-1] + index = args[0] + # Verify no compression was configured + assert hasattr(index.vector_search, 'compressions') + assert len(index.vector_search.compressions) == 0 - mock_search_client.upload_documents.assert_called_once() - args, _ = mock_search_client.upload_documents.call_args - documents = args[0] - # Update expected_doc to include extra fields from payload. - expected_doc = { - "id": "doc1", - "vector": [0.1, 0.2, 0.3], - "payload": json.dumps({"user_id": "user1", "run_id": "run1"}), - "user_id": "user1", - "run_id": "run1" - } - assert documents[0] == expected_doc - -def test_search_preFilter(azure_ai_search_instance): - instance, mock_search_client, _ = azure_ai_search_instance - # Setup a fake search result returned by the mocked search method. - fake_result = { - "id": "doc1", - "@search.score": 0.95, - "payload": json.dumps({"user_id": "user1"}) - } - # Configure the mock to return an iterator (list) with fake_result. 
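    # (A plain list stands in for the SDK's SearchItemPaged iterator here; the
    # code under test only iterates the results, so any iterable works.)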
- mock_search_client.search.return_value = [fake_result] - query_vector = [0.1, 0.2, 0.3] - results = instance.search(query_vector, limit=1, filters={"user_id": "user1"}, vector_filter_mode="preFilter") +def test_initialization_with_float_precision(mock_clients): + """Test initialization with different float precision settings.""" + mock_search_client, mock_index_client, _ = mock_clients + + # Test with half precision (float16) + instance = AzureAISearch( + service_name="test-service", + collection_name="float16-index", + api_key="test-api-key", + embedding_model_dims=768, + use_float16=True + ) + assert instance.use_float16 is True + + # Capture the index creation call + args, _ = mock_index_client.create_or_update_index.call_args_list[-1] + index = args[0] + # Find the vector field and check its type + vector_field = next((f for f in index.fields if f.name == "vector"), None) + assert vector_field is not None + assert "Edm.Half" in vector_field.type + + # Test with full precision (float32) + instance = AzureAISearch( + service_name="test-service", + collection_name="float32-index", + api_key="test-api-key", + embedding_model_dims=768, + use_float16=False + ) + assert instance.use_float16 is False + + # Capture the index creation call + args, _ = mock_index_client.create_or_update_index.call_args_list[-1] + index = args[0] + # Find the vector field and check its type + vector_field = next((f for f in index.fields if f.name == "vector"), None) + assert vector_field is not None + assert "Edm.Single" in vector_field.type - # Verify that the search method was called with vector_filter_mode="preFilter". - mock_search_client.search.assert_called_once() - _, called_kwargs = mock_search_client.search.call_args - assert called_kwargs.get("vector_filter_mode") == "preFilter" - # Verify that the output is parsed correctly. - assert len(results) == 1 - assert results[0].id == "doc1" - assert results[0].score == 0.95 - assert results[0].payload == {"user_id": "user1"} +# --- Tests for create_col method --- -def test_search_postFilter(azure_ai_search_instance): - instance, mock_search_client, _ = azure_ai_search_instance - # Setup a fake search result for postFilter. 
- fake_result = { - "id": "doc2", - "@search.score": 0.85, - "payload": json.dumps({"user_id": "user2"}) - } - mock_search_client.search.return_value = [fake_result] +def test_create_col(azure_ai_search_instance): + """Test the create_col method creates an index with the correct configuration.""" + instance, _, mock_index_client = azure_ai_search_instance + + # create_col is called during initialization, so we check the call that was already made + mock_index_client.create_or_update_index.assert_called_once() + + # Verify the index configuration + args, _ = mock_index_client.create_or_update_index.call_args + index = args[0] + + # Check basic properties + assert index.name == "test-index" + assert len(index.fields) == 6 # id, user_id, run_id, agent_id, vector, payload + + # Check that required fields are present + field_names = [f.name for f in index.fields] + assert "id" in field_names + assert "vector" in field_names + assert "payload" in field_names + assert "user_id" in field_names + assert "run_id" in field_names + assert "agent_id" in field_names + + # Check that id is the key field + id_field = next(f for f in index.fields if f.name == "id") + assert id_field.key is True + + # Check vector search configuration + assert index.vector_search is not None + assert len(index.vector_search.profiles) == 1 + assert index.vector_search.profiles[0].name == "my-vector-config" + assert index.vector_search.profiles[0].algorithm_configuration_name == "my-algorithms-config" + + # Check algorithms + assert len(index.vector_search.algorithms) == 1 + assert index.vector_search.algorithms[0].name == "my-algorithms-config" + assert "HnswAlgorithmConfiguration" in str(type(index.vector_search.algorithms[0])) + + # With binary compression and float16, we should have compression configuration + assert len(index.vector_search.compressions) == 1 + assert index.vector_search.compressions[0].compression_name == "myCompression" + assert "BinaryQuantizationCompression" in str(type(index.vector_search.compressions[0])) - query_vector = [0.4, 0.5, 0.6] - results = instance.search(query_vector, limit=1, filters={"user_id": "user2"}, vector_filter_mode="postFilter") - mock_search_client.search.assert_called_once() - _, called_kwargs = mock_search_client.search.call_args - assert called_kwargs.get("vector_filter_mode") == "postFilter" +def test_create_col_scalar_compression(mock_clients): + """Test creating a collection with scalar compression.""" + mock_search_client, mock_index_client, _ = mock_clients + + instance = AzureAISearch( + service_name="test-service", + collection_name="scalar-index", + api_key="test-api-key", + embedding_model_dims=768, + compression_type="scalar" + ) + + # Verify the index configuration + args, _ = mock_index_client.create_or_update_index.call_args + index = args[0] + + # Check compression configuration + assert len(index.vector_search.compressions) == 1 + assert index.vector_search.compressions[0].compression_name == "myCompression" + assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0])) + + # Check profile references compression + assert index.vector_search.profiles[0].compression_name == "myCompression" - assert len(results) == 1 - assert results[0].id == "doc2" - assert results[0].score == 0.85 - assert results[0].payload == {"user_id": "user2"} -def test_delete(azure_ai_search_instance): - instance, mock_search_client, _ = azure_ai_search_instance - vector_id = "doc1" - # Set delete_documents to return an iterable with a successful response. 
- mock_search_client.delete_documents.return_value = [{"status": True, "id": vector_id}] - instance.delete(vector_id) - mock_search_client.delete_documents.assert_called_once_with(documents=[{"id": vector_id}]) +def test_create_col_no_compression(mock_clients): + """Test creating a collection with no compression.""" + mock_search_client, mock_index_client, _ = mock_clients + + instance = AzureAISearch( + service_name="test-service", + collection_name="no-compression-index", + api_key="test-api-key", + embedding_model_dims=768, + compression_type=None + ) + + # Verify the index configuration + args, _ = mock_index_client.create_or_update_index.call_args + index = args[0] + + # Check compression configuration - should be empty + assert len(index.vector_search.compressions) == 0 + + # Check profile doesn't reference compression + assert index.vector_search.profiles[0].compression_name is None -def test_update(azure_ai_search_instance): - instance, mock_search_client, _ = azure_ai_search_instance - vector_id = "doc1" - new_vector = [0.7, 0.8, 0.9] - new_payload = {"user_id": "updated"} - # Set merge_or_upload_documents to return an iterable with a successful response. - mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": vector_id}] - instance.update(vector_id, vector=new_vector, payload=new_payload) - mock_search_client.merge_or_upload_documents.assert_called_once() - kwargs = mock_search_client.merge_or_upload_documents.call_args.kwargs - document = kwargs["documents"][0] - assert document["id"] == vector_id - assert document["vector"] == new_vector - assert document["payload"] == json.dumps(new_payload) - # The update method will also add the 'user_id' field. - assert document["user_id"] == "updated" - -def test_get(azure_ai_search_instance): - instance, mock_search_client, _ = azure_ai_search_instance - fake_result = { - "id": "doc1", - "payload": json.dumps({"user_id": "user1"}) - } - mock_search_client.get_document.return_value = fake_result - result = instance.get("doc1") - mock_search_client.get_document.assert_called_once_with(key="doc1") - assert result.id == "doc1" - assert result.payload == {"user_id": "user1"} - assert result.score is None - -def test_list(azure_ai_search_instance): + +# --- Tests for insert method --- + +def test_insert_single(azure_ai_search_instance): + """Test inserting a single vector.""" instance, mock_search_client, _ = azure_ai_search_instance - fake_result = { - "id": "doc1", - "@search.score": 0.99, - "payload": json.dumps({"user_id": "user1"}) - } - mock_search_client.search.return_value = [fake_result] - # Call list with a simple filter. - results = instance.list(filters={"user_id": "user1"}, limit=1) - # Verify the search method was called with the proper parameters. 
- expected_filter = instance._build_filter_expression({"user_id": "user1"}) - mock_search_client.search.assert_called_once_with( - search_text="*", - filter=expected_filter, - top=1 - ) - assert isinstance(results, list) - assert len(results) == 1 - assert results[0].id == "doc1" + vectors = [[0.1, 0.2, 0.3]] + payloads = [{"user_id": "user1", "run_id": "run1", "agent_id": "agent1"}] + ids = ["doc1"] -# --- New tests for practical end-user scenarios --- + instance.insert(vectors, payloads, ids) -def test_bulk_insert(azure_ai_search_instance): - """Test inserting a batch of documents (common for initial data loading).""" + # Verify upload_documents was called correctly + mock_search_client.upload_documents.assert_called_once() + args, _ = mock_search_client.upload_documents.call_args + documents = args[0] + + # Verify document structure + assert len(documents) == 1 + assert documents[0]["id"] == "doc1" + assert documents[0]["vector"] == [0.1, 0.2, 0.3] + assert documents[0]["payload"] == json.dumps(payloads[0]) + assert documents[0]["user_id"] == "user1" + assert documents[0]["run_id"] == "run1" + assert documents[0]["agent_id"] == "agent1" + + +def test_insert_multiple(azure_ai_search_instance): + """Test inserting multiple vectors in one call.""" instance, mock_search_client, _ = azure_ai_search_instance - # Create a batch of 10 documents - num_docs = 10 - vectors = [[0.1, 0.2, 0.3] for _ in range(num_docs)] + # Create multiple vectors + num_docs = 3 + vectors = [[float(i)/10, float(i+1)/10, float(i+2)/10] for i in range(num_docs)] payloads = [{"user_id": f"user{i}", "content": f"Test content {i}"} for i in range(num_docs)] ids = [f"doc{i}" for i in range(num_docs)] @@ -208,25 +401,35 @@ def test_bulk_insert(azure_ai_search_instance): {"status": True, "id": id_val} for id_val in ids ] - # Insert the batch + # Insert the documents instance.insert(vectors, payloads, ids) - # Verify the call + # Verify upload_documents was called with correct documents mock_search_client.upload_documents.assert_called_once() args, _ = mock_search_client.upload_documents.call_args documents = args[0] + + # Verify all documents were included assert len(documents) == num_docs - # Verify the first and last document + # Check first document assert documents[0]["id"] == "doc0" - assert documents[-1]["id"] == f"doc{num_docs-1}" + assert documents[0]["vector"] == [0.0, 0.1, 0.2] + assert documents[0]["payload"] == json.dumps(payloads[0]) + assert documents[0]["user_id"] == "user0" + + # Check last document + assert documents[2]["id"] == "doc2" + assert documents[2]["vector"] == [0.2, 0.3, 0.4] + assert documents[2]["payload"] == json.dumps(payloads[2]) + assert documents[2]["user_id"] == "user2" -def test_insert_error_handling(azure_ai_search_instance): - """Test how the class handles Azure errors during insertion.""" +def test_insert_with_error(azure_ai_search_instance): + """Test insert when Azure returns an error for one or more documents.""" instance, mock_search_client, _ = azure_ai_search_instance - # Configure mock to return a failure for one document + # Configure mock to return an error for one document mock_search_client.upload_documents.return_value = [ {"status": False, "id": "doc1", "errorMessage": "Azure error"} ] @@ -235,274 +438,105 @@ def test_insert_error_handling(azure_ai_search_instance): payloads = [{"user_id": "user1"}] ids = ["doc1"] - # Exception should be raised + # Insert should raise an exception with pytest.raises(Exception) as exc_info: instance.insert(vectors, payloads, ids) - assert 
"Insert failed" in str(exc_info.value) - - -def test_search_with_complex_filters(azure_ai_search_instance): - """Test searching with multiple filter conditions as a user might need.""" - instance, mock_search_client, _ = azure_ai_search_instance + assert "Insert failed for document doc1" in str(exc_info.value) - # Configure mock response - mock_search_client.search.return_value = [ - { - "id": "doc1", - "@search.score": 0.95, - "payload": json.dumps({"user_id": "user1", "run_id": "run123", "agent_id": "agent456"}) - } + # Configure mock to return mixed success/failure for multiple documents + mock_search_client.upload_documents.return_value = [ + {"status": True, "id": "doc1"}, + {"status": False, "id": "doc2", "errorMessage": "Azure error"} ] - # Search with multiple filters (common in multi-tenant or segmented applications) - filters = { - "user_id": "user1", - "run_id": "run123", - "agent_id": "agent456" - } - results = instance.search([0.1, 0.2, 0.3], filters=filters) - - # Verify search was called with the correct filter expression - mock_search_client.search.assert_called_once() - _, kwargs = mock_search_client.search.call_args - assert "filter" in kwargs - - # The filter should contain all three conditions - filter_expr = kwargs["filter"] - assert "user_id eq 'user1'" in filter_expr - assert "run_id eq 'run123'" in filter_expr - assert "agent_id eq 'agent456'" in filter_expr - assert " and " in filter_expr # Conditions should be joined by AND - - -def test_empty_search_results(azure_ai_search_instance): - """Test behavior when search returns no results (common edge case).""" - instance, mock_search_client, _ = azure_ai_search_instance - - # Configure mock to return empty results - mock_search_client.search.return_value = [] + vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + payloads = [{"user_id": "user1"}, {"user_id": "user2"}] + ids = ["doc1", "doc2"] - # Search with a non-matching query - results = instance.search([0.9, 0.9, 0.9], limit=5) + # Insert should raise an exception + with pytest.raises(Exception) as exc_info: + instance.insert(vectors, payloads, ids) - # Verify result handling - assert len(results) == 0 + assert "Insert failed for document doc2" in str(exc_info.value) -def test_get_nonexistent_document(azure_ai_search_instance): - """Test behavior when getting a document that doesn't exist (should handle gracefully).""" +def test_insert_with_missing_payload_fields(azure_ai_search_instance): + """Test inserting with payloads missing some of the expected fields.""" instance, mock_search_client, _ = azure_ai_search_instance - - # Configure mock to raise ResourceNotFoundError - mock_search_client.get_document.side_effect = ResourceNotFoundError("Document not found") - - # Get a non-existent document - result = instance.get("nonexistent_id") - - # Should return None instead of raising exception - assert result is None + vectors = [[0.1, 0.2, 0.3]] + payloads = [{"content": "Some content without user_id, run_id, or agent_id"}] + ids = ["doc1"] + instance.insert(vectors, payloads, ids) -def test_azure_service_error(azure_ai_search_instance): - """Test handling of Azure service errors (important for robustness).""" - instance, mock_search_client, _ = azure_ai_search_instance - - # Configure mock to raise HttpResponseError - http_error = HttpResponseError("Azure service is unavailable") - mock_search_client.search.side_effect = http_error - - # Attempt to search - with pytest.raises(HttpResponseError): - instance.search([0.1, 0.2, 0.3]) + # Verify upload_documents was called 
+    mock_search_client.upload_documents.assert_called_once()
+    args, _ = mock_search_client.upload_documents.call_args
+    documents = args[0]
 
-    # Verify search was attempted
-    mock_search_client.search.assert_called_once()
+    # Verify document has payload but not the extra fields
+    assert len(documents) == 1
+    assert documents[0]["id"] == "doc1"
+    assert documents[0]["vector"] == [0.1, 0.2, 0.3]
+    assert documents[0]["payload"] == json.dumps(payloads[0])
+    assert "user_id" not in documents[0]
+    assert "run_id" not in documents[0]
+    assert "agent_id" not in documents[0]
 
-def test_realistic_workflow(azure_ai_search_instance):
-    """Test a realistic workflow: insert → search → update → search again."""
+def test_insert_with_http_error(azure_ai_search_instance):
+    """Test insert when Azure client throws an HTTP error."""
     instance, mock_search_client, _ = azure_ai_search_instance
 
-    # 1. Insert a document
-    vector = [0.1, 0.2, 0.3]
-    payload = {"user_id": "user1", "content": "Initial content"}
-    doc_id = "workflow_doc"
+    # Configure mock to raise an HttpResponseError
+    mock_search_client.upload_documents.side_effect = HttpResponseError("Azure service error")
 
-    mock_search_client.upload_documents.return_value = [{"status": True, "id": doc_id}]
-    instance.insert([vector], [payload], [doc_id])
-
-    # 2. Search for the document
-    mock_search_client.search.return_value = [
-        {
-            "id": doc_id,
-            "@search.score": 0.95,
-            "payload": json.dumps(payload)
-        }
-    ]
-    results = instance.search(vector, filters={"user_id": "user1"})
-    assert len(results) == 1
-    assert results[0].id == doc_id
-
-    # 3. Update the document
-    updated_payload = {"user_id": "user1", "content": "Updated content"}
-    mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": doc_id}]
-    instance.update(doc_id, payload=updated_payload)
-
-    # 4. Search again to get updated document
-    mock_search_client.search.return_value = [
-        {
-            "id": doc_id,
-            "@search.score": 0.95,
-            "payload": json.dumps(updated_payload)
-        }
-    ]
-    results = instance.search(vector, filters={"user_id": "user1"})
-    assert len(results) == 1
-    assert results[0].id == doc_id
-    assert results[0].payload["content"] == "Updated content"
-
-
-def test_sanitize_special_characters(azure_ai_search_instance):
-    """Test that special characters in filter values are properly sanitized."""
-    instance, mock_search_client, _ = azure_ai_search_instance
-
-    # Configure mock response
-    mock_search_client.search.return_value = [
-        {
-            "id": "doc1",
-            "@search.score": 0.95,
-            "payload": json.dumps({"user_id": "user's-data"})
-        }
-    ]
-
-    # Search with a filter that has special characters (common in real-world data)
-    filters = {"user_id": "user's-data"}
-    results = instance.search([0.1, 0.2, 0.3], filters=filters)
+    vectors = [[0.1, 0.2, 0.3]]
+    payloads = [{"user_id": "user1"}]
+    ids = ["doc1"]
 
-    # Verify search was called with properly escaped filter
-    mock_search_client.search.assert_called_once()
-    _, kwargs = mock_search_client.search.call_args
-    assert "filter" in kwargs
+    # Insert should propagate the HTTP error
+    with pytest.raises(HttpResponseError) as exc_info:
+        instance.insert(vectors, payloads, ids)
 
-    # The filter should have properly escaped single quotes
-    filter_expr = kwargs["filter"]
-    assert "user_id eq 'user''s-data'" in filter_expr
-
+    assert "Azure service error" in str(exc_info.value)
 
-def test_list_collections(azure_ai_search_instance):
-    """Test listing all collections/indexes (for management interfaces)."""
-    instance, _, mock_index_client = azure_ai_search_instance
-
-    # List the collections
-    collections = instance.list_cols()
-
-    # Verify the correct method was called
-    mock_index_client.list_index_names.assert_called_once()
-
-    # Check the result
-    assert collections == ["test-index"]
+# --- Tests for search method ---
 
-def test_filter_with_numeric_values(azure_ai_search_instance):
-    """Test filtering with numeric values (common for faceted search)."""
+def test_search_basic(azure_ai_search_instance):
+    """Test basic vector search without filters."""
     instance, mock_search_client, _ = azure_ai_search_instance
 
-    # Configure mock response
+    # Configure mock to return search results
     mock_search_client.search.return_value = [
         {
             "id": "doc1",
             "@search.score": 0.95,
-            "payload": json.dumps({"user_id": "user1", "count": 42})
+            "payload": json.dumps({"content": "Test content"})
         }
     ]
 
-    # Search with a numeric filter
-    # Note: In the actual implementation, numeric fields might need to be in the payload
-    filters = {"count": 42}
-    results = instance.search([0.1, 0.2, 0.3], filters=filters)
+    # Search with a vector
+    query_vector = [0.1, 0.2, 0.3]
+    results = instance.search(query_vector, limit=5)
 
-    # Verify the filter expression
+    # Verify search was called correctly
     mock_search_client.search.assert_called_once()
     _, kwargs = mock_search_client.search.call_args
-    filter_expr = kwargs["filter"]
-    assert "count eq 42" in filter_expr  # No quotes for numbers
-
-
-def test_error_on_update_nonexistent(azure_ai_search_instance):
-    """Test behavior when updating a document that doesn't exist."""
-    instance, mock_search_client, _ = azure_ai_search_instance
-
-    # Configure mock to return a failure for the update
-    mock_search_client.merge_or_upload_documents.return_value = [
-        {"status": False, "id": "nonexistent", "errorMessage": "Document not found"}
-    ]
 
+    # Check parameters
+    assert len(kwargs["vector_queries"]) == 1
+    assert kwargs["vector_queries"][0].vector == query_vector
+    assert kwargs["vector_queries"][0].k_nearest_neighbors == 5
+    assert kwargs["vector_queries"][0].fields == "vector"
+    assert kwargs["filter"] is None  # No filters
+    assert kwargs["top"] == 5
+    assert kwargs["vector_filter_mode"] == "preFilter"  # Default mode
 
-    # Attempt to update a non-existent document
-    with pytest.raises(Exception) as exc_info:
-        instance.update("nonexistent", payload={"new": "data"})
 
-    assert "Update failed" in str(exc_info.value)
-
-
-def test_different_compression_types():
-    """Test creating instances with different compression types (important for performance tuning)."""
-    with patch("mem0.vector_stores.azure_ai_search.SearchClient"), \
-         patch("mem0.vector_stores.azure_ai_search.SearchIndexClient"):
-
-        # Test with scalar compression
-        scalar_instance = AzureAISearch(
-            service_name="test-service",
-            collection_name="scalar-index",
-            api_key="test-api-key",
-            embedding_model_dims=3,
-            compression_type="scalar",
-            use_float16=False
-        )
-
-        # Test with no compression
-        no_compression_instance = AzureAISearch(
-            service_name="test-service",
-            collection_name="no-compression-index",
-            api_key="test-api-key",
-            embedding_model_dims=3,
-            compression_type=None,
-            use_float16=False
-        )
-
-        # No assertions needed - we're just verifying that initialization doesn't fail
-
-
-def test_high_dimensional_vectors():
-    """Test handling of high-dimensional vectors typical in AI embeddings."""
-    with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \
-         patch("mem0.vector_stores.azure_ai_search.SearchIndexClient"):
-
-        # Configure the mock client
-        mock_search_client = MockSearchClient.return_value
-        mock_search_client.upload_documents = Mock()
-        mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1"}]
-
-        # Create an instance with higher dimensions like those from embedding models
-        high_dim_instance = AzureAISearch(
-            service_name="test-service",
-            collection_name="high-dim-index",
-            api_key="test-api-key",
-            embedding_model_dims=1536,  # Common for models like OpenAI's embeddings
-            compression_type="binary",  # Compression often used with high-dim vectors
-            use_float16=True  # Reduced precision often used for memory efficiency
-        )
-
-        # Create a high-dimensional vector (stub with zeros for testing)
-        high_dim_vector = [0.0] * 1536
-        payload = {"user_id": "user1"}
-        doc_id = "high_dim_doc"
-
-        # Insert the document
-        high_dim_instance.insert([high_dim_vector], [payload], [doc_id])
-
-        # Verify the insert was called with the full vector
-        mock_search_client.upload_documents.assert_called_once()
-        args, _ = mock_search_client.upload_documents.call_args
-        documents = args[0]
-        assert len(documents[0]["vector"]) == 1536
+    # Check results
+    assert len(results) == 1
+    assert results[0].id == "doc1"
+    assert results[0].score == 0.95
+    assert results[0].payload == {"content": "Test content"}
\ No newline at end of file

From 54fffeaf05771b8271ee70b6e99c961d3bbd4d30 Mon Sep 17 00:00:00 2001
From: Farzad
Date: Mon, 17 Mar 2025 11:34:59 -0500
Subject: [PATCH 4/4] revert lock changes

---
 poetry.lock | 47 -----------------------------------------------
 1 file changed, 47 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index 82e4ae6045..96c789b242 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -241,53 +241,6 @@ azure-core = ">=1.28.0"
 isodate = ">=0.6.0"
 typing-extensions = ">=4.6.0"
 
-[[package]]
-name = "azure-common"
-version = "1.1.28"
-description = "Microsoft Azure Client Library for Python (Common)" -optional = false -python-versions = "*" -files = [ - {file = "azure-common-1.1.28.zip", hash = "sha256:4ac0cd3214e36b6a1b6a442686722a5d8cc449603aa833f3f0f40bda836704a3"}, - {file = "azure_common-1.1.28-py2.py3-none-any.whl", hash = "sha256:5c12d3dcf4ec20599ca6b0d3e09e86e146353d443e7fcc050c9a19c1f9df20ad"}, -] - -[[package]] -name = "azure-core" -version = "1.32.0" -description = "Microsoft Azure Core Library for Python" -optional = false -python-versions = ">=3.8" -files = [ - {file = "azure_core-1.32.0-py3-none-any.whl", hash = "sha256:eac191a0efb23bfa83fddf321b27b122b4ec847befa3091fa736a5c32c50d7b4"}, - {file = "azure_core-1.32.0.tar.gz", hash = "sha256:22b3c35d6b2dae14990f6c1be2912bf23ffe50b220e708a28ab1bb92b1c730e5"}, -] - -[package.dependencies] -requests = ">=2.21.0" -six = ">=1.11.0" -typing-extensions = ">=4.6.0" - -[package.extras] -aio = ["aiohttp (>=3.0)"] - -[[package]] -name = "azure-search-documents" -version = "11.5.2" -description = "Microsoft Azure Cognitive Search Client Library for Python" -optional = false -python-versions = ">=3.8" -files = [ - {file = "azure_search_documents-11.5.2-py3-none-any.whl", hash = "sha256:c949d011008a4b0bcee3db91132741b4e4d50ddb3f7e2f48944d949d4b413b11"}, - {file = "azure_search_documents-11.5.2.tar.gz", hash = "sha256:98977dd1fa4978d3b7d8891a0856b3becb6f02cc07ff2e1ea40b9c7254ada315"}, -] - -[package.dependencies] -azure-common = ">=1.1" -azure-core = ">=1.28.0" -isodate = ">=0.6.0" -typing-extensions = ">=4.6.0" - [[package]] name = "backoff" version = "2.2.1"