Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import chromadb
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.types import GetResult, QueryResult
from chromadb.api.types import GetResult, Metadata, OneOrMany, QueryResult
from haystack import default_from_dict, default_to_dict, logging
from haystack.dataclasses import Document
from haystack.document_stores.errors import DocumentStoreError
Expand Down Expand Up @@ -178,7 +178,8 @@ async def _ensure_initialized_async(self):
embedding_function=self._embedding_func,
)

def _prepare_get_kwargs(self, filters: Optional[dict[str, Any]] = None) -> dict[str, Any]:
@staticmethod
def _prepare_get_kwargs(filters: Optional[dict[str, Any]] = None) -> dict[str, Any]:
"""
Prepare kwargs for Chroma get operations.
"""
Expand All @@ -195,7 +196,8 @@ def _prepare_get_kwargs(self, filters: Optional[dict[str, Any]] = None) -> dict[

return kwargs

def _prepare_query_kwargs(self, filters: Optional[dict[str, Any]] = None) -> dict[str, Any]:
@staticmethod
def _prepare_query_kwargs(filters: Optional[dict[str, Any]] = None) -> dict[str, Any]:
"""
Prepare kwargs for Chroma query operations.
"""
Expand Down Expand Up @@ -246,7 +248,7 @@ def filter_documents(self, filters: Optional[dict[str, Any]] = None) -> list[Doc
self._ensure_initialized()
assert self._collection is not None

kwargs = self._prepare_get_kwargs(filters)
kwargs = ChromaDocumentStore._prepare_get_kwargs(filters)
result = self._collection.get(**kwargs)

return self._get_result_to_documents(result)
Expand All @@ -266,12 +268,63 @@ async def filter_documents_async(self, filters: Optional[dict[str, Any]] = None)
await self._ensure_initialized_async()
assert self._async_collection is not None

kwargs = self._prepare_get_kwargs(filters)
kwargs = ChromaDocumentStore._prepare_get_kwargs(filters)
result = await self._async_collection.get(**kwargs)

return self._get_result_to_documents(result)

def _convert_document_to_chroma(self, doc: Document) -> Optional[dict[str, Any]]:
@staticmethod
def _filter_metadata(meta: dict[str, Any]) -> dict[str, str | int | float | bool | None]:
    """
    Filters metadata to only include supported types for Chroma.

    :param meta: The metadata dictionary to filter.
    :returns: A new dictionary with only valid metadata values.
    """
    # keep only values Chroma can store; None is allowed explicitly
    valid_meta: dict[str, str | int | float | bool | None] = {
        key: value
        for key, value in meta.items()
        if value is None or isinstance(value, SUPPORTED_TYPES_FOR_METADATA_VALUES)
    }

    dropped = [key for key in meta if key not in valid_meta]
    if dropped:
        logger.warning(
            "Metadata contains values of unsupported types for the keys: {keys}. "
            "These items will be discarded. Supported types are: {types}.",
            keys=", ".join(dropped),
            types=", ".join([t.__name__ for t in SUPPORTED_TYPES_FOR_METADATA_VALUES]),
        )

    return valid_meta

@staticmethod
def _prepare_metadata_update(
    matching_docs: list[Document], meta: dict[str, Any]
) -> tuple[list[str], list[Metadata]]:
    """
    Prepares document IDs and updated metadata for batch update operations.

    :param matching_docs: List of documents to update.
    :param meta: New metadata to merge with existing document metadata.
    :returns: Tuple of (ids_to_update, updated_metadata).
    """
    doc_ids = [doc.id for doc in matching_docs]
    # new `meta` values win over a document's existing metadata on key collisions;
    # unsupported value types are dropped before the update is sent to Chroma
    merged_metadata: list[Metadata] = [
        cast(Metadata, ChromaDocumentStore._filter_metadata({**(doc.meta or {}), **meta}))
        for doc in matching_docs
    ]
    return doc_ids, merged_metadata

@staticmethod
def _convert_document_to_chroma(doc: Document) -> Optional[dict[str, Any]]:
"""
Converts a Haystack Document to a Chroma document.
"""
Expand Down Expand Up @@ -353,7 +406,7 @@ def write_documents(
assert self._collection is not None

for doc in documents:
data = self._convert_document_to_chroma(doc)
data = ChromaDocumentStore._convert_document_to_chroma(doc)
if data is not None:
self._collection.add(**data)

Expand Down Expand Up @@ -384,7 +437,7 @@ async def write_documents_async(
assert self._async_collection is not None

for doc in documents:
data = self._convert_document_to_chroma(doc)
data = ChromaDocumentStore._convert_document_to_chroma(doc)
if data is not None:
await self._async_collection.add(**data)

Expand Down Expand Up @@ -414,6 +467,181 @@ async def delete_documents_async(self, document_ids: list[str]) -> None:

await self._async_collection.delete(ids=document_ids)

def delete_by_filter(self, filters: dict[str, Any]) -> int:
    """
    Deletes all documents that match the provided filters.

    :param filters: The filters to apply to select documents for deletion.
        For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering)
    :returns: The number of documents deleted.
    """
    self._ensure_initialized()
    assert self._collection is not None

    try:
        converted = _convert_filters(filters)

        # ChromaDB's delete API does not report how many records it removed,
        # so count the matching documents up front.
        n_matching = len(self.filter_documents(filters))
        if n_matching == 0:
            return 0

        delete_kwargs: dict[str, Any] = {}
        if converted.ids:
            # an ID-based filter maps directly onto a delete-by-ids call
            delete_kwargs["ids"] = converted.ids
        else:
            # otherwise pass through the where/where_document clauses
            for clause_name in ("where", "where_document"):
                clause = getattr(converted, clause_name)
                if clause:
                    delete_kwargs[clause_name] = clause

        self._collection.delete(**delete_kwargs)

        logger.info(
            "Deleted {n_docs} documents from collection '{name}' using filters.",
            n_docs=n_matching,
            name=self._collection_name,
        )
        return n_matching
    except Exception as e:
        msg = f"Failed to delete documents by filter from ChromaDB: {e!s}"
        raise DocumentStoreError(msg) from e

async def delete_by_filter_async(self, filters: dict[str, Any]) -> int:
    """
    Asynchronously deletes all documents that match the provided filters.

    Asynchronous methods are only supported for HTTP connections.

    :param filters: The filters to apply to select documents for deletion.
        For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering)
    :returns: The number of documents deleted.
    """
    await self._ensure_initialized_async()
    assert self._async_collection is not None

    try:
        converted = _convert_filters(filters)

        # ChromaDB's delete API does not report how many records it removed,
        # so count the matching documents up front.
        n_matching = len(await self.filter_documents_async(filters))
        if n_matching == 0:
            return 0

        delete_kwargs: dict[str, Any] = {}
        if converted.ids:
            # an ID-based filter maps directly onto a delete-by-ids call
            delete_kwargs["ids"] = converted.ids
        else:
            # otherwise pass through the where/where_document clauses
            for clause_name in ("where", "where_document"):
                clause = getattr(converted, clause_name)
                if clause:
                    delete_kwargs[clause_name] = clause

        await self._async_collection.delete(**delete_kwargs)

        logger.info(
            "Deleted {n_docs} documents from collection '{name}' using filters.",
            n_docs=n_matching,
            name=self._collection_name,
        )
        return n_matching
    except Exception as e:
        msg = f"Failed to delete documents by filter from ChromaDB: {e!s}"
        raise DocumentStoreError(msg) from e

def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int:
    """
    Updates the metadata of all documents that match the provided filters.

    **Note**: This operation is not atomic. Documents matching the filter are fetched first,
    then updated. If documents are modified between the fetch and update operations,
    those changes may be lost.

    :param filters: The filters to apply to select documents for updating.
        For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering)
    :param meta: The metadata fields to update. This will be merged with existing metadata.
    :returns: The number of documents updated.
    """
    self._ensure_initialized()
    assert self._collection is not None

    try:
        docs = self.filter_documents(filters)
        if not docs:
            return 0

        doc_ids, new_metadata = ChromaDocumentStore._prepare_metadata_update(docs, meta)

        # one batched update call instead of one call per document
        self._collection.update(
            ids=doc_ids,
            metadatas=cast(OneOrMany[Metadata], new_metadata),
        )

        logger.info(
            "Updated {n_docs} documents in collection '{name}' using filters.",
            n_docs=len(doc_ids),
            name=self._collection_name,
        )
        return len(doc_ids)
    except Exception as e:
        msg = f"Failed to update documents by filter in ChromaDB: {e!s}"
        raise DocumentStoreError(msg) from e

async def update_by_filter_async(self, filters: dict[str, Any], meta: dict[str, Any]) -> int:
    """
    Asynchronously updates the metadata of all documents that match the provided filters.

    Asynchronous methods are only supported for HTTP connections.

    **Note**: This operation is not atomic. Documents matching the filter are fetched first,
    then updated. If documents are modified between the fetch and update operations,
    those changes may be lost.

    :param filters: The filters to apply to select documents for updating.
        For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering)
    :param meta: The metadata fields to update. This will be merged with existing metadata.
    :returns: The number of documents updated.
    """
    await self._ensure_initialized_async()
    assert self._async_collection is not None

    try:
        docs = await self.filter_documents_async(filters)
        if not docs:
            return 0

        doc_ids, new_metadata = ChromaDocumentStore._prepare_metadata_update(docs, meta)

        # one batched update call instead of one call per document
        await self._async_collection.update(
            ids=doc_ids,
            metadatas=cast(OneOrMany[Metadata], new_metadata),
        )

        logger.info(
            "Updated {n_docs} documents in collection '{name}' using filters.",
            n_docs=len(doc_ids),
            name=self._collection_name,
        )
        return len(doc_ids)
    except Exception as e:
        msg = f"Failed to update documents by filter in ChromaDB: {e!s}"
        raise DocumentStoreError(msg) from e

def delete_all_documents(self, *, recreate_index: bool = False) -> None:
"""
Deletes all documents in the document store.
Expand Down Expand Up @@ -511,7 +739,7 @@ def search(
self._ensure_initialized()
assert self._collection is not None

kwargs = self._prepare_query_kwargs(filters)
kwargs = ChromaDocumentStore._prepare_query_kwargs(filters)
results = self._collection.query(
query_texts=queries,
n_results=top_k,
Expand Down Expand Up @@ -539,7 +767,7 @@ async def search_async(
await self._ensure_initialized_async()
assert self._async_collection is not None

kwargs = self._prepare_query_kwargs(filters)
kwargs = ChromaDocumentStore._prepare_query_kwargs(filters)
results = await self._async_collection.query(
query_texts=queries,
n_results=top_k,
Expand Down Expand Up @@ -567,7 +795,7 @@ def search_embeddings(
self._ensure_initialized()
assert self._collection is not None

kwargs = self._prepare_query_kwargs(filters)
kwargs = ChromaDocumentStore._prepare_query_kwargs(filters)
results = self._collection.query(
query_embeddings=cast(list[Sequence[float]], query_embeddings),
n_results=top_k,
Expand Down Expand Up @@ -598,7 +826,7 @@ async def search_embeddings_async(
await self._ensure_initialized_async()
assert self._async_collection is not None

kwargs = self._prepare_query_kwargs(filters)
kwargs = ChromaDocumentStore._prepare_query_kwargs(filters)
results = await self._async_collection.query(
query_embeddings=cast(list[Sequence[float]], query_embeddings),
n_results=top_k,
Expand Down
4 changes: 4 additions & 0 deletions integrations/chroma/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
Expand Down
Loading