Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ class OpenSearchDocumentStore:

Usage example:
```python
from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore
from haystack_integrations.document_stores.opensearch import (
OpenSearchDocumentStore,
)
from haystack import Document

document_store = OpenSearchDocumentStore(hosts="localhost:9200")
Expand Down Expand Up @@ -420,6 +422,10 @@ def _prepare_bulk_write_request(
opensearch_actions = []
for doc in documents:
doc_dict = doc.to_dict()

# Extract routing from document metadata
doc_routing = doc_dict.pop("_routing", None)

if "sparse_embedding" in doc_dict:
sparse_embedding = doc_dict.pop("sparse_embedding", None)
if sparse_embedding:
Expand All @@ -429,13 +435,17 @@ def _prepare_bulk_write_request(
"The `sparse_embedding` field will be ignored.",
id=doc.id,
)
opensearch_actions.append(
{
"_op_type": action,
"_id": doc.id,
"_source": doc_dict,
}
)

action_dict = {
"_op_type": action,
"_id": doc.id,
"_source": doc_dict,
}

if doc_routing is not None:
action_dict["_routing"] = doc_routing

opensearch_actions.append(action_dict)

return {
"client": self._client if not is_async else self._async_client,
Expand Down Expand Up @@ -549,18 +559,36 @@ def _deserialize_document(hit: dict[str, Any]) -> Document:
return Document.from_dict(data)

def _prepare_bulk_delete_request(
self, *, document_ids: list[str], is_async: bool, refresh: Literal["wait_for", True, False]
self,
*,
document_ids: list[str],
is_async: bool,
refresh: Literal["wait_for", True, False],
routing: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
def action_generator():
for id_ in document_ids:
action = {"_op_type": "delete", "_id": id_}
# Add routing if provided for this document ID
if routing and id_ in routing and routing[id_] is not None:
action["_routing"] = routing[id_]
yield action

return {
"client": self._client if not is_async else self._async_client,
"actions": ({"_op_type": "delete", "_id": id_} for id_ in document_ids),
"actions": action_generator(),
"refresh": refresh,
"index": self._index,
"raise_on_error": False,
"max_chunk_bytes": self._max_chunk_bytes,
}

def delete_documents(self, document_ids: list[str], refresh: Literal["wait_for", True, False] = "wait_for") -> None:
def delete_documents(
self,
document_ids: list[str],
refresh: Literal["wait_for", True, False] = "wait_for",
routing: Optional[dict[str, str]] = None,
) -> None:
"""
Deletes documents that match the provided `document_ids` from the document store.

Expand All @@ -570,16 +598,24 @@ def delete_documents(self, document_ids: list[str], refresh: Literal["wait_for",
- `False`: Do not refresh (better performance for bulk operations).
- `"wait_for"`: Wait for the next refresh cycle (default, ensures read-your-writes consistency).
For more details, see the [OpenSearch refresh documentation](https://opensearch.org/docs/latest/api-reference/document-apis/index-document/).
:param routing: A dictionary mapping document IDs to their routing values.
Routing values are used to determine the shard where documents are stored.
If provided, the routing value for each document will be used during deletion.
"""

self._ensure_initialized()

bulk(**self._prepare_bulk_delete_request(document_ids=document_ids, is_async=False, refresh=refresh))
bulk(
**self._prepare_bulk_delete_request(
document_ids=document_ids, is_async=False, refresh=refresh, routing=routing
)
)

async def delete_documents_async(
self,
document_ids: list[str],
refresh: Literal["wait_for", True, False] = "wait_for",
routing: Optional[dict[str, str]] = None,
) -> None:
"""
Asynchronously deletes documents that match the provided `document_ids` from the document store.
Expand All @@ -590,11 +626,18 @@ async def delete_documents_async(
- `False`: Do not refresh (better performance for bulk operations).
- `"wait_for"`: Wait for the next refresh cycle (default, ensures read-your-writes consistency).
For more details, see the [OpenSearch refresh documentation](https://opensearch.org/docs/latest/api-reference/document-apis/index-document/).
:param routing: A dictionary mapping document IDs to their routing values.
Routing values are used to determine the shard where documents are stored.
If provided, the routing value for each document will be used during deletion.
"""
await self._ensure_initialized_async()
assert self._async_client is not None

await async_bulk(**self._prepare_bulk_delete_request(document_ids=document_ids, is_async=True, refresh=refresh))
await async_bulk(
**self._prepare_bulk_delete_request(
document_ids=document_ids, is_async=True, refresh=refresh, routing=routing
)
)

def _prepare_delete_all_request(self, *, refresh: bool) -> dict[str, Any]:
return {
Expand Down
82 changes: 82 additions & 0 deletions integrations/opensearch/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,46 @@ def test_get_default_mappings(_mock_opensearch_client):
}


@patch("haystack_integrations.document_stores.opensearch.document_store.bulk")
def test_routing_extracted_from_metadata(mock_bulk, document_store):
    """
    Check that `_routing` in a document's meta is moved to the bulk action.

    `bulk` is patched, so the assertions inspect the actions that
    `write_documents` hands to it rather than anything stored in OpenSearch.
    """
    mock_bulk.return_value = (2, [])

    docs = [
        Document(id="1", content="Doc", meta={"_routing": "user_a", "other": "data"}),
        Document(id="2", content="Doc"),
    ]
    document_store.write_documents(docs)

    actions = list(mock_bulk.call_args.kwargs["actions"])
    assert len(actions) == 2
    routed, unrouted = actions

    # Routing should be at action level, not in _source.
    assert routed["_routing"] == "user_a"
    # Meta keys are serialized flat into `_source` (the "other" assertion below
    # relies on that), so the leak check must look at the top level of
    # `_source`. Checking only a nested "meta" dict would pass vacuously.
    assert "_routing" not in routed["_source"]
    assert "_routing" not in routed["_source"].get("meta", {})

    # Other metadata should be preserved
    assert routed["_source"]["other"] == "data"

    # Second doc has no routing
    assert "_routing" not in unrouted
    assert "_routing" not in unrouted["_source"]
    assert "_routing" not in unrouted["_source"].get("meta", {})


@patch("haystack_integrations.document_stores.opensearch.document_store.bulk")
def test_routing_in_delete(mock_bulk, document_store):
    """
    Check that `delete_documents` attaches per-document routing to delete actions.

    IDs present in the routing mapping get a `_routing` entry on their action;
    an ID missing from the mapping gets none.
    """
    mock_bulk.return_value = (2, [])

    routing_by_id = {"1": "user_a", "2": "user_b"}
    document_store.delete_documents(["1", "2", "3"], routing=routing_by_id)

    # `bulk` is mocked, so the actions iterable it received is still unconsumed.
    emitted = list(mock_bulk.call_args.kwargs["actions"])

    for position, expected_routing in enumerate(("user_a", "user_b")):
        assert emitted[position]["_routing"] == expected_routing

    # Document "3" has no entry in the mapping.
    assert "_routing" not in emitted[2]


@pytest.mark.integration
class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest):
"""
Expand Down Expand Up @@ -574,3 +614,45 @@ def test_update_by_filter(self, document_store: OpenSearchDocumentStore):
)
assert len(draft_docs) == 1
assert draft_docs[0].meta["category"] == "B"

@pytest.mark.integration
def test_write_with_routing(self, document_store: OpenSearchDocumentStore):
    """
    Write routed and unrouted documents and check that `_routing` is not stored as metadata.
    """
    batch = [
        Document(id="1", content="User A doc", meta={"_routing": "user_a", "category": "test"}),
        Document(id="2", content="User B doc", meta={"_routing": "user_b"}),
        Document(id="3", content="No routing"),
    ]

    assert document_store.write_documents(batch) == 3
    assert document_store.count_documents() == 3

    stored = document_store.filter_documents()

    # The routing key must not come back as part of any document's meta.
    assert all("_routing" not in doc.meta for doc in stored)

    # Remaining metadata is preserved; documents that only carried routing end up with empty meta.
    meta_by_id = {doc.id: doc.meta for doc in stored}
    assert meta_by_id["1"]["category"] == "test"
    assert meta_by_id["2"] == {}
    assert meta_by_id["3"] == {}

@pytest.mark.integration
def test_delete_with_routing(self, document_store: OpenSearchDocumentStore):
    """
    Delete routed documents by passing the same routing values they were written with.
    """
    document_store.write_documents(
        [
            Document(id="1", content="Doc 1", meta={"_routing": "user_a"}),
            Document(id="2", content="Doc 2", meta={"_routing": "user_b"}),
            Document(id="3", content="Doc 3"),
        ]
    )

    # Maps each document ID to the routing value used at write time.
    routing_by_id = {"1": "user_a", "2": "user_b"}
    document_store.delete_documents(["1", "2"], routing=routing_by_id)

    # Only the third document should be left.
    remaining = document_store.count_documents()
    assert remaining == 1
39 changes: 39 additions & 0 deletions integrations/opensearch/tests/test_document_store_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,45 @@ async def test_delete_all_documents_no_index_recreation(self, document_store: Op
assert len(results) == 1
assert results[0].content == "New document after delete all"

@pytest.mark.asyncio
async def test_write_with_routing(self, document_store: OpenSearchDocumentStore):
    """
    Async variant: write routed and unrouted documents and check `_routing` is not stored as metadata.
    """
    batch = [
        Document(id="1", content="User A doc", meta={"_routing": "user_a", "category": "test"}),
        Document(id="2", content="User B doc", meta={"_routing": "user_b"}),
        Document(id="3", content="No routing"),
    ]

    written_count = await document_store.write_documents_async(batch)
    assert written_count == 3
    assert await document_store.count_documents_async() == 3

    stored = await document_store.filter_documents_async()

    # The routing key must not come back as part of any document's meta.
    assert all("_routing" not in doc.meta for doc in stored)

    # Remaining metadata is preserved; documents that only carried routing end up with empty meta.
    meta_by_id = {doc.id: doc.meta for doc in stored}
    assert meta_by_id["1"]["category"] == "test"
    assert meta_by_id["2"] == {}
    assert meta_by_id["3"] == {}

@pytest.mark.asyncio
async def test_delete_with_routing(self, document_store: OpenSearchDocumentStore):
    """
    Async variant: delete routed documents by passing the routing values they were written with.
    """
    await document_store.write_documents_async(
        [
            Document(id="1", content="Doc 1", meta={"_routing": "user_a"}),
            Document(id="2", content="Doc 2", meta={"_routing": "user_b"}),
            Document(id="3", content="Doc 3"),
        ]
    )

    # Maps each document ID to the routing value used at write time.
    routing_by_id = {"1": "user_a", "2": "user_b"}
    await document_store.delete_documents_async(["1", "2"], routing=routing_by_id)

    # Only the third document should be left.
    remaining = await document_store.count_documents_async()
    assert remaining == 1

async def test_delete_by_filter_async(self, document_store: OpenSearchDocumentStore):
docs = [
Document(content="Doc 1", meta={"category": "A"}),
Expand Down