diff --git a/integrations/chroma/README.md b/integrations/chroma/README.md index fadfc4c09..c8b009ff9 100644 --- a/integrations/chroma/README.md +++ b/integrations/chroma/README.md @@ -11,3 +11,6 @@ ## Contributing Refer to the general [Contribution Guidelines](https://github.com/deepset-ai/haystack-core-integrations/blob/main/CONTRIBUTING.md). + +To run integration tests locally, you need a Chroma server running. +Start one with: `docker run -p 8000:8000 chromadb/chroma:latest`. diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 871020f0d..2d0633bde 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -8,6 +8,7 @@ import chromadb from chromadb.api.models.AsyncCollection import AsyncCollection from chromadb.api.types import GetResult, Metadata, OneOrMany, QueryResult +from chromadb.config import Settings from haystack import default_from_dict, default_to_dict, logging from haystack.dataclasses import Document from haystack.document_stores.errors import DocumentStoreError @@ -40,6 +41,7 @@ def __init__( port: Optional[int] = None, distance_function: Literal["l2", "cosine", "ip"] = "l2", metadata: Optional[dict] = None, + client_settings: Optional[dict[str, Any]] = None, **embedding_function_params: Any, ): """ @@ -67,6 +69,11 @@ def __init__( :param metadata: a dictionary of chromadb collection parameters passed directly to chromadb's client method `create_collection`. If it contains the key `"hnsw:space"`, the value will take precedence over the `distance_function` parameter above. + :param client_settings: a dictionary of Chroma Settings configuration options passed to + `chromadb.config.Settings`. These settings configure the underlying Chroma client behavior. + For available options, see [Chroma's config.py](https://github.com/chroma-core/chroma/blob/main/chromadb/config.py). + **Note**: specifying these settings may interfere with standard client initialization parameters. + This option is intended for advanced customization. :param embedding_function_params: additional parameters to pass to the embedding function. """ @@ -84,6 +91,7 @@ def __init__( self._embedding_function_params = embedding_function_params self._distance_function = distance_function self._metadata = metadata + self._client_settings = client_settings self._persist_path = persist_path self._host = host @@ -102,18 +110,29 @@ def _ensure_initialized(self): "You cannot specify both options." ) raise ValueError(error_message) + + # Use dict to conditionally pass settings because Chroma doesn't accept settings=None + client_kwargs: dict[str, Any] = {} + if self._client_settings: + try: + client_kwargs["settings"] = Settings(**self._client_settings) + except ValueError as e: + msg = f"Invalid client_settings ({self._client_settings}): {e}" + raise ValueError(msg) from e + if self._host and self._port is not None: # Remote connection via HTTP client client = chromadb.HttpClient( host=self._host, port=self._port, + **client_kwargs, ) elif self._persist_path is None: # In-memory storage - client = chromadb.Client() + client = chromadb.Client(**client_kwargs) else: # Local persistent storage - client = chromadb.PersistentClient(path=self._persist_path) + client = chromadb.PersistentClient(path=self._persist_path, **client_kwargs) self._client = client # store client for potential future use @@ -148,9 +167,19 @@ async def _ensure_initialized_async(self): ) raise ValueError(error_message) + # Use dict to conditionally pass settings because Chroma doesn't accept settings=None + client_kwargs: dict[str, Any] = {} + if self._client_settings: + try: + client_kwargs["settings"] = Settings(**self._client_settings) + except ValueError as e: + msg = f"Invalid client_settings ({self._client_settings}): {e}" + raise ValueError(msg) from e + client = await chromadb.AsyncHttpClient( host=self._host, port=self._port, + **client_kwargs, ) self._async_client = client # store client for potential future use @@ -862,6 +891,7 @@ def to_dict(self) -> dict[str, Any]: host=self._host, port=self._port, distance_function=self._distance_function, + client_settings=self._client_settings, **self._embedding_function_params, ) diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index 0d6e39ad0..ac457567a 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -9,6 +9,7 @@ from unittest import mock import pytest +from chromadb.api.shared_system_client import SharedSystemClient from haystack.dataclasses import ByteStream, Document from haystack.testing.document_store import ( TEST_EMBEDDING_1, @@ -20,6 +21,19 @@ from haystack_integrations.document_stores.chroma import ChromaDocumentStore +@pytest.fixture +def clear_chroma_system_cache(): + """ + Chroma's in-memory client uses a singleton pattern with an internal cache. + Once a client is created with certain settings, Chroma rejects creating another + with different settings in the same process. This fixture clears the cache + before and after tests that use custom client settings. + """ + SharedSystemClient.clear_system_cache() + yield + SharedSystemClient.clear_system_cache() + + class TestDocumentStore(CountDocumentsTest, DeleteDocumentsTest, FilterDocumentsTest): """ Common test cases will be provided by `DocumentStoreBaseTests` but @@ -75,6 +89,10 @@ def test_init_http_connection(self): assert store._host == "localhost" assert store._port == 8000 + def test_init_with_client_settings(self): + store = ChromaDocumentStore(client_settings={"anonymized_telemetry": False}) + assert store._client_settings == {"anonymized_telemetry": False} + def test_invalid_initialization_both_host_and_persist_path(self): """ Test that providing both host and persist_path raises an error. @@ -83,9 +101,33 @@ def test_invalid_initialization_both_host_and_persist_path(self): store = ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost") store._ensure_initialized() + def test_client_settings_applied(self, clear_chroma_system_cache): + """ + Chroma's in-memory client uses a singleton pattern with an internal cache. + Once a client is created with certain settings, Chroma rejects creating another + with different settings in the same process. We clear the cache before and after + this test to avoid conflicts with other tests that use default settings. + """ + store = ChromaDocumentStore(client_settings={"anonymized_telemetry": False}) + store._ensure_initialized() + assert store._client.get_settings().anonymized_telemetry is False + + def test_invalid_client_settings(self, clear_chroma_system_cache): + store = ChromaDocumentStore( + client_settings={ + "invalid_setting_name": "some_value", + "another_fake_setting": 123, + } + ) + with pytest.raises(ValueError, match="Invalid client_settings"): + store._ensure_initialized() + def test_to_dict(self, request): ds = ChromaDocumentStore( - collection_name=request.node.name, embedding_function="HuggingFaceEmbeddingFunction", api_key="1234567890" + collection_name=request.node.name, + embedding_function="HuggingFaceEmbeddingFunction", + api_key="1234567890", + client_settings={"anonymized_telemetry": False}, ) ds_dict = ds.to_dict() assert ds_dict == { @@ -98,6 +140,7 @@ def test_to_dict(self, request): "port": None, "api_key": "1234567890", "distance_function": "l2", + "client_settings": {"anonymized_telemetry": False}, }, } @@ -114,6 +157,7 @@ def test_from_dict(self): "port": None, "api_key": "1234567890", "distance_function": "l2", + "client_settings": {"anonymized_telemetry": False}, }, } @@ -121,6 +165,7 @@ def test_from_dict(self): assert ds._collection_name == collection_name assert ds._embedding_function == function_name assert ds._embedding_function_params == {"api_key": "1234567890"} + assert ds._client_settings == {"anonymized_telemetry": False} def test_same_collection_name_reinitialization(self): ChromaDocumentStore("test_1") diff --git a/integrations/chroma/tests/test_document_store_async.py b/integrations/chroma/tests/test_document_store_async.py index b09ad1ebe..c0fc9771a 100644 --- a/integrations/chroma/tests/test_document_store_async.py +++ b/integrations/chroma/tests/test_document_store_async.py @@ -18,6 +18,7 @@ sys.platform == "win32", reason="We do not run the Chroma server on Windows and async is only supported with HTTP connections", ) +@pytest.mark.integration @pytest.mark.asyncio class TestDocumentStoreAsync: @pytest.fixture @@ -96,7 +97,29 @@ async def test_comparison_equal_async(self, document_store, filterable_docs): ) self.assert_documents_are_equal(result, [d for d in filterable_docs if d.meta.get("number") == 100]) - @pytest.mark.integration + async def test_client_settings_applied_async(self): + store = ChromaDocumentStore( + host="localhost", + port=8000, + client_settings={"anonymized_telemetry": False}, + collection_name=f"{uuid.uuid1()}-async-settings", + ) + await store._ensure_initialized_async() + assert store._async_client.get_settings().anonymized_telemetry is False + + async def test_invalid_client_settings_async(self): + store = ChromaDocumentStore( + host="localhost", + port=8000, + client_settings={ + "invalid_setting_name": "some_value", + "another_fake_setting": 123, + }, + collection_name=f"{uuid.uuid1()}-async-invalid", + ) + with pytest.raises(ValueError, match="Invalid client_settings"): + await store._ensure_initialized_async() + async def test_search_async(self): document_store = ChromaDocumentStore(host="localhost", port=8000, collection_name="my_custom_collection") diff --git a/integrations/chroma/tests/test_retriever.py b/integrations/chroma/tests/test_retriever.py index 23c7b0bdc..a3cf71e88 100644 --- a/integrations/chroma/tests/test_retriever.py +++ b/integrations/chroma/tests/test_retriever.py @@ -41,6 +41,7 @@ def test_to_dict(self, request): "port": None, "api_key": "1234567890", "distance_function": "l2", + "client_settings": None, }, }, }, @@ -131,6 +132,7 @@ def test_to_dict(self, request): "port": None, "api_key": "1234567890", "distance_function": "l2", + "client_settings": None, }, }, },