Skip to content
Merged
3 changes: 3 additions & 0 deletions integrations/chroma/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@
## Contributing

Refer to the general [Contribution Guidelines](https://github.com/deepset-ai/haystack-core-integrations/blob/main/CONTRIBUTING.md).

To run integration tests locally, you need a Chroma server running.
Start one with: `docker run -p 8000:8000 chromadb/chroma:latest`.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import chromadb
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.types import GetResult, Metadata, OneOrMany, QueryResult
from chromadb.config import Settings
from haystack import default_from_dict, default_to_dict, logging
from haystack.dataclasses import Document
from haystack.document_stores.errors import DocumentStoreError
Expand Down Expand Up @@ -40,6 +41,7 @@ def __init__(
port: Optional[int] = None,
distance_function: Literal["l2", "cosine", "ip"] = "l2",
metadata: Optional[dict] = None,
client_settings: Optional[dict[str, Any]] = None,
**embedding_function_params: Any,
):
"""
Expand Down Expand Up @@ -67,6 +69,11 @@ def __init__(
:param metadata: a dictionary of chromadb collection parameters passed directly to chromadb's client
method `create_collection`. If it contains the key `"hnsw:space"`, the value will take precedence over the
`distance_function` parameter above.
:param client_settings: a dictionary of Chroma Settings configuration options passed to
`chromadb.config.Settings`. These settings configure the underlying Chroma client behavior.
For available options, see [Chroma's config.py](https://github.com/chroma-core/chroma/blob/main/chromadb/config.py).
**Note**: specifying these settings may interfere with standard client initialization parameters.
This option is intended for advanced customization.
:param embedding_function_params: additional parameters to pass to the embedding function.
"""

Expand All @@ -84,6 +91,7 @@ def __init__(
self._embedding_function_params = embedding_function_params
self._distance_function = distance_function
self._metadata = metadata
self._client_settings = client_settings

self._persist_path = persist_path
self._host = host
Expand All @@ -102,18 +110,29 @@ def _ensure_initialized(self):
"You cannot specify both options."
)
raise ValueError(error_message)

# Use dict to conditionally pass settings because Chroma doesn't accept settings=None
client_kwargs: dict[str, Any] = {}
if self._client_settings:
try:
client_kwargs["settings"] = Settings(**self._client_settings)
except ValueError as e:
msg = f"Invalid client_settings ({self._client_settings}): {e}"
raise ValueError(msg) from e

if self._host and self._port is not None:
# Remote connection via HTTP client
client = chromadb.HttpClient(
host=self._host,
port=self._port,
**client_kwargs,
)
elif self._persist_path is None:
# In-memory storage
client = chromadb.Client()
client = chromadb.Client(**client_kwargs)
else:
# Local persistent storage
client = chromadb.PersistentClient(path=self._persist_path)
client = chromadb.PersistentClient(path=self._persist_path, **client_kwargs)

self._client = client # store client for potential future use

Expand Down Expand Up @@ -148,9 +167,19 @@ async def _ensure_initialized_async(self):
)
raise ValueError(error_message)

# Use dict to conditionally pass settings because Chroma doesn't accept settings=None
client_kwargs: dict[str, Any] = {}
if self._client_settings:
try:
client_kwargs["settings"] = Settings(**self._client_settings)
except ValueError as e:
msg = f"Invalid client_settings ({self._client_settings}): {e}"
raise ValueError(msg) from e

client = await chromadb.AsyncHttpClient(
host=self._host,
port=self._port,
**client_kwargs,
)

self._async_client = client # store client for potential future use
Expand Down Expand Up @@ -862,6 +891,7 @@ def to_dict(self) -> dict[str, Any]:
host=self._host,
port=self._port,
distance_function=self._distance_function,
client_settings=self._client_settings,
**self._embedding_function_params,
)

Expand Down
47 changes: 46 additions & 1 deletion integrations/chroma/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from unittest import mock

import pytest
from chromadb.api.shared_system_client import SharedSystemClient
from haystack.dataclasses import ByteStream, Document
from haystack.testing.document_store import (
TEST_EMBEDDING_1,
Expand All @@ -20,6 +21,19 @@
from haystack_integrations.document_stores.chroma import ChromaDocumentStore


@pytest.fixture
def clear_chroma_system_cache():
"""
Chroma's in-memory client uses a singleton pattern with an internal cache.
Once a client is created with certain settings, Chroma rejects creating another
with different settings in the same process. This fixture clears the cache
before and after tests that use custom client settings.
"""
SharedSystemClient.clear_system_cache()
yield
SharedSystemClient.clear_system_cache()


class TestDocumentStore(CountDocumentsTest, DeleteDocumentsTest, FilterDocumentsTest):
"""
Common test cases will be provided by `DocumentStoreBaseTests` but
Expand Down Expand Up @@ -75,6 +89,10 @@ def test_init_http_connection(self):
assert store._host == "localhost"
assert store._port == 8000

def test_init_with_client_settings(self):
store = ChromaDocumentStore(client_settings={"anonymized_telemetry": False})
assert store._client_settings == {"anonymized_telemetry": False}

def test_invalid_initialization_both_host_and_persist_path(self):
"""
Test that providing both host and persist_path raises an error.
Expand All @@ -83,9 +101,33 @@ def test_invalid_initialization_both_host_and_persist_path(self):
store = ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost")
store._ensure_initialized()

def test_client_settings_applied(self, clear_chroma_system_cache):
"""
Chroma's in-memory client uses a singleton pattern with an internal cache.
Once a client is created with certain settings, Chroma rejects creating another
with different settings in the same process. We clear the cache before and after
this test to avoid conflicts with other tests that use default settings.
"""
store = ChromaDocumentStore(client_settings={"anonymized_telemetry": False})
store._ensure_initialized()
assert store._client.get_settings().anonymized_telemetry is False

def test_invalid_client_settings(self, clear_chroma_system_cache):
store = ChromaDocumentStore(
client_settings={
"invalid_setting_name": "some_value",
"another_fake_setting": 123,
}
)
with pytest.raises(ValueError, match="Invalid client_settings"):
store._ensure_initialized()

def test_to_dict(self, request):
ds = ChromaDocumentStore(
collection_name=request.node.name, embedding_function="HuggingFaceEmbeddingFunction", api_key="1234567890"
collection_name=request.node.name,
embedding_function="HuggingFaceEmbeddingFunction",
api_key="1234567890",
client_settings={"anonymized_telemetry": False},
)
ds_dict = ds.to_dict()
assert ds_dict == {
Expand All @@ -98,6 +140,7 @@ def test_to_dict(self, request):
"port": None,
"api_key": "1234567890",
"distance_function": "l2",
"client_settings": {"anonymized_telemetry": False},
},
}

Expand All @@ -114,13 +157,15 @@ def test_from_dict(self):
"port": None,
"api_key": "1234567890",
"distance_function": "l2",
"client_settings": {"anonymized_telemetry": False},
},
}

ds = ChromaDocumentStore.from_dict(ds_dict)
assert ds._collection_name == collection_name
assert ds._embedding_function == function_name
assert ds._embedding_function_params == {"api_key": "1234567890"}
assert ds._client_settings == {"anonymized_telemetry": False}

def test_same_collection_name_reinitialization(self):
ChromaDocumentStore("test_1")
Expand Down
25 changes: 24 additions & 1 deletion integrations/chroma/tests/test_document_store_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
sys.platform == "win32",
reason="We do not run the Chroma server on Windows and async is only supported with HTTP connections",
)
@pytest.mark.integration
@pytest.mark.asyncio
class TestDocumentStoreAsync:
@pytest.fixture
Expand Down Expand Up @@ -96,7 +97,29 @@ async def test_comparison_equal_async(self, document_store, filterable_docs):
)
self.assert_documents_are_equal(result, [d for d in filterable_docs if d.meta.get("number") == 100])

@pytest.mark.integration
async def test_client_settings_applied_async(self):
store = ChromaDocumentStore(
host="localhost",
port=8000,
client_settings={"anonymized_telemetry": False},
collection_name=f"{uuid.uuid1()}-async-settings",
)
await store._ensure_initialized_async()
assert store._async_client.get_settings().anonymized_telemetry is False

async def test_invalid_client_settings_async(self):
store = ChromaDocumentStore(
host="localhost",
port=8000,
client_settings={
"invalid_setting_name": "some_value",
"another_fake_setting": 123,
},
collection_name=f"{uuid.uuid1()}-async-invalid",
)
with pytest.raises(ValueError, match="Invalid client_settings"):
await store._ensure_initialized_async()

async def test_search_async(self):
document_store = ChromaDocumentStore(host="localhost", port=8000, collection_name="my_custom_collection")

Expand Down
2 changes: 2 additions & 0 deletions integrations/chroma/tests/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_to_dict(self, request):
"port": None,
"api_key": "1234567890",
"distance_function": "l2",
"client_settings": None,
},
},
},
Expand Down Expand Up @@ -131,6 +132,7 @@ def test_to_dict(self, request):
"port": None,
"api_key": "1234567890",
"distance_function": "l2",
"client_settings": None,
},
},
},
Expand Down