diff --git a/integrations/opensearch/pyproject.toml b/integrations/opensearch/pyproject.toml index 5eaa8bae9..78c5ab562 100644 --- a/integrations/opensearch/pyproject.toml +++ b/integrations/opensearch/pyproject.toml @@ -25,7 +25,9 @@ classifiers = [ ] dependencies = [ "haystack-ai>=2.14.0", - "opensearch-py[async]>=2.4.0,<3"] + "opensearch-py[async]>=2.4.0,<3", + "httpx>=0.28.1" +] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/opensearch#readme" diff --git a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py index 2b8d6694b..45cf600b1 100644 --- a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py +++ b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py @@ -8,12 +8,15 @@ from math import exp from typing import Any, Literal, Optional, Union +import httpx +import requests from haystack import default_from_dict, default_to_dict, logging from haystack.dataclasses import Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy from haystack.utils.auth import Secret from opensearchpy import AsyncHttpConnection, AsyncOpenSearch, OpenSearch +from opensearchpy.exceptions import SerializationError from opensearchpy.helpers import async_bulk, bulk from haystack_integrations.document_stores.opensearch.auth import AsyncAWSAuth, AWSAuth @@ -21,9 +24,12 @@ logger = logging.getLogger(__name__) +SPECIAL_FIELDS = {"content", "embedding", "id", "score", "sparse_embedding", "blob"} Hosts = Union[str, list[Union[str, Mapping[str, Union[str, int]]]]] +ResponseFormat = Literal["json", "jdbc", "csv", "raw"] + # document scores are essentially unbounded and will be scaled to values between 0 and 1 if scale_score is set to # True. Scaling uses the expit function (inverse of the logit function) after applying a scaling factor # (e.g., BM25_SCALING_FACTOR for the bm25_retrieval method). @@ -1141,3 +1147,511 @@ def _render_custom_query(self, custom_query: Any, substitutions: dict[str, Any]) return substitutions.get(custom_query, custom_query) return custom_query + + def count_documents_by_filter(self, filters: dict) -> int: + """ + Returns the number of documents that match the provided filters. + + :param filters: The filters to apply to count documents. + For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/docs/metadata-filtering) + :returns: The number of documents that match the filters. + """ + self._ensure_initialized() + assert self._client is not None + + normalized_filters = normalize_filters(filters) + body = {"query": {"bool": {"filter": normalized_filters}}} + return self._client.count(index=self._index, body=body)["count"] + + async def count_documents_by_filter_async(self, filters: dict) -> int: + """ + Asynchronously returns the number of documents that match the provided filters. + + :param filters: The filters to apply to count documents. + For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/docs/metadata-filtering) + :returns: The number of documents that match the filters. 
+ """ + await self._ensure_initialized_async() + assert self._async_client is not None + + normalized_filters = normalize_filters(filters) + body = {"query": {"bool": {"filter": normalized_filters}}} + return (await self._async_client.count(index=self._index, body=body))["count"] + + @staticmethod + def _build_cardinality_aggregations(index_mapping: dict[str, Any]) -> dict[str, Any]: + """ + Builds cardinality aggregations for all metadata fields in the index mapping. + + See: https://docs.opensearch.org/latest/aggregations/metric/cardinality/ + """ + aggs = {} + for field_name in index_mapping.keys(): + if field_name not in SPECIAL_FIELDS: + aggs[f"{field_name}_cardinality"] = {"cardinality": {"field": field_name}} + return aggs + + @staticmethod + def _build_distinct_values_query_body(filters: dict, aggs: dict[str, Any]) -> dict[str, Any]: + """ + Builds the query body for distinct values counting with filters and aggregations. + """ + if filters: + normalized_filters = normalize_filters(filters) + return { + "query": {"bool": {"filter": normalized_filters}}, + "aggs": aggs, + "size": 0, # We only need aggregations, not documents + } + else: + # No filters - match all documents + return { + "query": {"match_all": {}}, + "aggs": aggs, + "size": 0, # We only need aggregations, not documents + } + + @staticmethod + def _extract_distinct_counts_from_aggregations( + aggregations: dict[str, Any], index_mapping: dict[str, Any] + ) -> dict[str, int]: + """ + Extracts distinct value counts from search result aggregations. + """ + distinct_counts = {} + for field_name in index_mapping.keys(): + if field_name not in SPECIAL_FIELDS: + agg_key = f"{field_name}_cardinality" + if agg_key in aggregations: + distinct_counts[field_name] = aggregations[agg_key]["value"] + return distinct_counts + + def count_distinct_values_by_filter(self, filters: dict) -> dict[str, int]: + """ + Returns the number of unique values for each meta field of the documents that match the provided filters. + + :param filters: The filters to apply to count documents. + For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/docs/metadata-filtering) + :returns: The number of unique values for each meta field of the documents that match the filters. + """ + self._ensure_initialized() + assert self._client is not None + + # use index mapping to get all fields + mapping = self._client.indices.get_mapping(index=self._index) + index_mapping = mapping[self._index]["mappings"]["properties"] + + # build aggregations for each metadata field + aggs = self._build_cardinality_aggregations(index_mapping) + if not aggs: + return {} + + # build and execute search query + body = self._build_distinct_values_query_body(filters, aggs) + result = self._client.search(index=self._index, body=body) + + # extract cardinality values from aggregations + return self._extract_distinct_counts_from_aggregations(result.get("aggregations", {}), index_mapping) + + async def count_distinct_values_by_filter_async(self, filters: dict) -> dict[str, int]: + """ + Asynchronously returns the number of unique values for each meta field of the documents that match the + provided filters. + + :param filters: The filters to apply to count documents. + For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/docs/metadata-filtering) + :returns: The number of unique values for each meta field of the documents that match the filters. 
+ """ + await self._ensure_initialized_async() + assert self._async_client is not None + + # use index mapping to get all fields + mapping = await self._async_client.indices.get_mapping(index=self._index) + index_mapping = mapping[self._index]["mappings"]["properties"] + + # build aggregations for each metadata field + aggs = self._build_cardinality_aggregations(index_mapping) + if not aggs: + return {} + + # build and execute search query + body = self._build_distinct_values_query_body(filters, aggs) + result = await self._async_client.search(index=self._index, body=body) + + # extract cardinality values from aggregations + return self._extract_distinct_counts_from_aggregations(result.get("aggregations", {}), index_mapping) + + def get_fields_info(self) -> dict[str, dict]: + """ + Returns the information about the fields in the index. + + :returns: The information about the fields in the index. + """ + self._ensure_initialized() + assert self._client is not None + + mapping = self._client.indices.get_mapping(index=self._index) + index_mapping = mapping[self._index]["mappings"]["properties"] + return index_mapping + + async def get_fields_info_async(self) -> dict[str, dict]: + """ + Asynchronously returns the information about the fields in the index. + + :returns: The information about the fields in the index. + """ + await self._ensure_initialized_async() + assert self._async_client is not None + + mapping = await self._async_client.indices.get_mapping(index=self._index) + index_mapping = mapping[self._index]["mappings"]["properties"] + return index_mapping + + @staticmethod + def _normalize_metadata_field_name(metadata_field: str) -> str: + """ + Normalizes a metadata field name by removing the "meta." prefix if present. + """ + return metadata_field[5:] if metadata_field.startswith("meta.") else metadata_field + + @staticmethod + def _build_min_max_query_body(field_name: str) -> dict[str, Any]: + """ + Builds the query body for getting min and max values using stats aggregation. + """ + return { + "query": {"match_all": {}}, + "aggs": { + "field_stats": { + "stats": { + "field": field_name, + } + } + }, + "size": 0, # We only need aggregations, not documents + } + + @staticmethod + def _extract_min_max_from_stats(stats: dict[str, Any]) -> dict[str, Any]: + """ + Extracts min and max values from stats aggregation results. + """ + min_value = stats.get("min") + max_value = stats.get("max") + return {"min": min_value, "max": max_value} + + def get_field_min_max(self, metadata_field: str) -> dict[str, Any]: + """ + Returns the minimum and maximum values for the given metadata field. + + :param metadata_field: The metadata field to get the minimum and maximum values for. + :returns: The minimum and maximum values for the given metadata field. + """ + self._ensure_initialized() + assert self._client is not None + + field_name = self._normalize_metadata_field_name(metadata_field) + body = self._build_min_max_query_body(field_name) + result = self._client.search(index=self._index, body=body) + stats = result.get("aggregations", {}).get("field_stats", {}) + + return self._extract_min_max_from_stats(stats) + + async def get_field_min_max_async(self, metadata_field: str) -> dict[str, Any]: + """ + Asynchronously returns the minimum and maximum values for the given metadata field. + + :param metadata_field: The metadata field to get the minimum and maximum values for. + :returns: The minimum and maximum values for the given metadata field. 
+ """ + await self._ensure_initialized_async() + assert self._async_client is not None + + field_name = self._normalize_metadata_field_name(metadata_field) + body = self._build_min_max_query_body(field_name) + result = await self._async_client.search(index=self._index, body=body) + stats = result.get("aggregations", {}).get("field_stats", {}) + + return self._extract_min_max_from_stats(stats) + + def get_field_unique_values( + self, metadata_field: str, search_term: str | None, from_: int, size: int + ) -> tuple[list[str], int]: + """ + Returns unique values for a metadata field, optionally filtered by a search term in the content. + + :param metadata_field: The metadata field to get unique values for. + :param search_term: Optional search term to filter documents by matching in the content field. + :param from_: The starting index for pagination. + :param size: The number of unique values to return. + :returns: A tuple containing (list of unique values, total count of unique values). + """ + self._ensure_initialized() + assert self._client is not None + + field_name = self._normalize_metadata_field_name(metadata_field) + + # filter by search_term if provided + query: dict[str, Any] = {"match_all": {}} + if search_term: + # Use match_phrase for exact phrase matching to avoid tokenization issues + query = {"match_phrase": {"content": search_term}} + + # Build aggregations + # Terms aggregation for paginated unique values + # Note: Terms aggregation doesn't support 'from' parameter directly, + # so we fetch from_ + size results and slice them + # Cardinality aggregation for total count + terms_size = from_ + size if from_ > 0 else size + body = { + "query": query, + "aggs": { + "unique_values": { + "terms": { + "field": field_name, + "size": terms_size, + } + }, + "total_count": { + "cardinality": { + "field": field_name, + } + }, + }, + "size": 0, # we only need aggregations, not documents + } + + result = self._client.search(index=self._index, body=body) + aggregations = result.get("aggregations", {}) + + # Extract unique values from terms aggregation buckets + unique_values_buckets = aggregations.get("unique_values", {}).get("buckets", []) + # Apply pagination by slicing the results + paginated_buckets = unique_values_buckets[from_ : from_ + size] + unique_values = [str(bucket["key"]) for bucket in paginated_buckets] + + # Extract total count from cardinality aggregation + total_count = int(aggregations.get("total_count", {}).get("value", 0)) + + return unique_values, total_count + + async def get_field_unique_values_async( + self, metadata_field: str, search_term: str | None, from_: int, size: int + ) -> tuple[list[str], int]: + """ + Asynchronously returns unique values for a metadata field, optionally filtered by a search term in the content. + + :param metadata_field: The metadata field to get unique values for. + :param search_term: Optional search term to filter documents by matching in the content field. + :param from_: The starting index for pagination. + :param size: The number of unique values to return. + :returns: A tuple containing (list of unique values, total count of unique values). 
+ """ + await self._ensure_initialized_async() + assert self._async_client is not None + + field_name = self._normalize_metadata_field_name(metadata_field) + + # filter by search_term if provided + query: dict[str, Any] = {"match_all": {}} + if search_term: + # Use match_phrase for exact phrase matching to avoid tokenization issues + query = {"match_phrase": {"content": search_term}} + + # Build aggregations + # Terms aggregation for paginated unique values + # Note: Terms aggregation doesn't support 'from' parameter directly, + # so we fetch from_ + size results and slice them + # Cardinality aggregation for total count + terms_size = from_ + size if from_ > 0 else size + body = { + "query": query, + "aggs": { + "unique_values": { + "terms": { + "field": field_name, + "size": terms_size, + } + }, + "total_count": { + "cardinality": { + "field": field_name, + } + }, + }, + "size": 0, # we only need aggregations, not documents + } + + result = await self._async_client.search(index=self._index, body=body) + aggregations = result.get("aggregations", {}) + + # Extract unique values from terms aggregation buckets + unique_values_buckets = aggregations.get("unique_values", {}).get("buckets", []) + # Apply pagination by slicing the results + paginated_buckets = unique_values_buckets[from_ : from_ + size] + unique_values = [str(bucket["key"]) for bucket in paginated_buckets] + + # Extract total count from cardinality aggregation + total_count = int(aggregations.get("total_count", {}).get("value", 0)) + + return unique_values, total_count + + def _prepare_sql_http_request_params( + self, base_url: str, response_format: ResponseFormat + ) -> tuple[str, dict[str, str], Any]: + """ + Prepares HTTP request parameters for SQL query execution. + """ + url = f"{base_url}/_plugins/_sql?format={response_format}" + headers = {"Content-Type": "application/json"} + auth = None + if self._http_auth: + if isinstance(self._http_auth, tuple): + auth = self._http_auth + elif isinstance(self._http_auth, AWSAuth): + # For AWS auth, we need to use the opensearchpy client + # Fall through to the try/except below + pass + return url, headers, auth + + @staticmethod + def _process_sql_response(response_data: Any, response_format: ResponseFormat) -> Any: + """ + Processes the SQL query response data. + """ + if response_format == "json": + # extract only the query results + if isinstance(response_data, dict) and "hits" in response_data: + hits = response_data.get("hits", {}).get("hits", []) + # extract _source from each hit, which contains the actual document data + return [hit.get("_source", {}) for hit in hits] + return response_data + else: + return response_data if isinstance(response_data, str) else str(response_data) + + def query_sql(self, query: str, response_format: ResponseFormat = "json") -> Any: + """ + Execute a raw OpenSearch SQL query against the index. + + :param query: The OpenSearch SQL query to execute + :param response_format: The format of the response. See https://docs.opensearch.org/latest/search-plugins/sql/response-formats/ + :returns: The query results in the specified format. For JSON format, returns a list of dictionaries + (the _source from each hit). For other formats (csv, jdbc, raw), returns the response as text. + + NOTE: For non-JSON formats (csv, jdbc, raw), use requests to make a raw HTTP request and get the text response + This avoids deserialization issues with the opensearchpy client. 
+ """ + self._ensure_initialized() + assert self._client is not None + + # For non-JSON formats, use requests directly to avoid deserialization issues + if response_format != "json": + try: + # Get connection info from the transport + connection = self._client.transport.get_connection() + base_url = connection.host + url, headers, auth = self._prepare_sql_http_request_params(base_url, response_format) + + verify = self._verify_certs if self._verify_certs is not None else True + timeout = self._timeout if self._timeout is not None else 30.0 + response = requests.post( + url, + json={"query": query}, + headers=headers, + auth=auth, + verify=verify, + timeout=timeout, + ) + response.raise_for_status() + return response.text + except Exception as e: + # If requests fails (e.g., AWS auth), fall back to opensearchpy + # which will raise SerializationError that we can handle + logger.error(f"Failed to execute SQL query in OpenSearch: {e!s}") + + try: + body = {"query": query} + params = {"format": response_format} + + response_data = self._client.transport.perform_request( + method="POST", + url="/_plugins/_sql", + params=params, + body=body, + ) + + return self._process_sql_response(response_data, response_format) + except SerializationError: + # If we get here, it means requests failed above (likely AWS auth) and opensearchpy can't deserialize the + # response. Re-raise as DocumentStoreError with a helpful message + msg = ( + f"Failed to execute SQL query in OpenSearch: Unable to deserialize {response_format} response. " + f"This format may not be supported with the current authentication method." + ) + raise DocumentStoreError(msg) from None + except Exception as e: + msg = f"Failed to execute SQL query in OpenSearch: {e!s}" + raise DocumentStoreError(msg) from e + + async def query_sql_async(self, query: str, response_format: ResponseFormat = "json") -> Any: + """ + Asynchronously execute a raw OpenSearch SQL query against the index. + + :param query: The OpenSearch SQL query to execute + :param response_format: The format of the response. See https://docs.opensearch.org/latest/search-plugins/sql/response-formats/ + :returns: The query results in the specified format. For JSON format, returns a list of dictionaries + (the _source from each hit). For other formats (csv, jdbc, raw), returns the response as text. + + NOTE: For non-JSON formats (csv, jdbc, raw), use httpx AsyncClient to make a raw HTTP request and get the text + response. This avoids deserialization issues with the opensearchpy client. 
+ """ + await self._ensure_initialized_async() + assert self._async_client is not None + + # For non-JSON formats, use httpx directly to avoid deserialization issues + if response_format != "json": + try: + # Get connection info from the transport + connection = self._async_client.transport.get_connection() + base_url = connection.host + url, headers, auth = self._prepare_sql_http_request_params(base_url, response_format) + + verify = self._verify_certs if self._verify_certs is not None else True + timeout = httpx.Timeout(self._timeout if self._timeout else 30.0) + + async with httpx.AsyncClient(verify=verify, timeout=timeout) as client: + response = await client.post( + url, + json={"query": query}, + headers=headers, + auth=auth, + ) + response.raise_for_status() + return response.text + except Exception as e: + logger.error(f"Failed to execute SQL query in OpenSearch: {e!s}") + + try: + body = {"query": query} + params = {"format": response_format} + + response_data = await self._async_client.transport.perform_request( + method="POST", + url="/_plugins/_sql", + params=params, + body=body, + ) + + return self._process_sql_response(response_data, response_format) + except SerializationError: + # If we get here, it means httpx failed above (likely AWS auth or not installed) and opensearchpy can't + # deserialize the response. Re-raise as DocumentStoreError with a helpful message + msg = ( + f"Failed to execute SQL query in OpenSearch: Unable to deserialize {response_format} response. " + f"This format may not be supported with the current authentication method. " + f"Consider installing httpx for better support." + ) + raise DocumentStoreError(msg) from None + except Exception as e: + msg = f"Failed to execute SQL query in OpenSearch: {e!s}" + raise DocumentStoreError(msg) from e diff --git a/integrations/opensearch/tests/test_bm25_retriever.py b/integrations/opensearch/tests/test_bm25_retriever.py index 03235c36b..d6bb6350e 100644 --- a/integrations/opensearch/tests/test_bm25_retriever.py +++ b/integrations/opensearch/tests/test_bm25_retriever.py @@ -424,6 +424,60 @@ def test_bm25_retriever_runtime_document_store_switching( assert len(results_1_again["documents"]) == 1 +@pytest.mark.integration +def test_bm25_retriever_document_structure_with_metadata(document_store): + """ + Test document structure with complex metadata (nested values, lists, etc.) 
+ """ + docs = [ + Document( + content="Python is versatile", + meta={ + "category": "programming", + "tags": ["python", "general-purpose"], + "rating": 4.5, + "active": True, + "author": {"name": "John", "role": "developer"}, + }, + id="python_doc", + ), + Document( + content="JavaScript is dynamic", + meta={ + "category": "programming", + "tags": ["javascript", "web"], + "rating": 4.8, + "active": True, + }, + id="js_doc", + ), + ] + document_store.write_documents(docs, refresh=True) + retriever = OpenSearchBM25Retriever(document_store=document_store) + + results = retriever.run(query="programming", top_k=2) + assert len(results["documents"]) == 2 + + for doc in results["documents"]: + # Verify structure + assert hasattr(doc, "content") + assert hasattr(doc, "meta") + assert isinstance(doc.meta, dict) + + # Verify complex metadata is preserved + assert "category" in doc.meta + assert "tags" in doc.meta + assert isinstance(doc.meta["tags"], list) + assert "rating" in doc.meta + + # Verify document can be serialized/deserialized + doc_dict = doc.to_dict() + doc_from_dict = Document.from_dict(doc_dict) + assert doc_from_dict.content == doc.content + assert doc_from_dict.meta == doc.meta + assert doc_from_dict.id == doc.id + + @pytest.mark.asyncio @pytest.mark.integration async def test_bm25_retriever_async_runtime_document_store_switching( diff --git a/integrations/opensearch/tests/test_document_store.py b/integrations/opensearch/tests/test_document_store.py index c75a5ca0a..3e224c498 100644 --- a/integrations/opensearch/tests/test_document_store.py +++ b/integrations/opensearch/tests/test_document_store.py @@ -593,6 +593,263 @@ def test_update_by_filter(self, document_store: OpenSearchDocumentStore): assert len(draft_docs) == 1 assert draft_docs[0].meta["category"] == "B" + def test_count_documents_by_filter(self, document_store: OpenSearchDocumentStore): + docs = [ + Document(content="Doc 1", meta={"category": "A", "status": "active"}), + Document(content="Doc 2", meta={"category": "B", "status": "active"}), + Document(content="Doc 3", meta={"category": "A", "status": "inactive"}), + Document(content="Doc 4", meta={"category": "A", "status": "active"}), + ] + document_store.write_documents(docs) + assert document_store.count_documents() == 4 + + count_a = document_store.count_documents_by_filter( + filters={"field": "meta.category", "operator": "==", "value": "A"} + ) + assert count_a == 3 + + count_a_active = document_store.count_documents_by_filter( + filters={ + "operator": "AND", + "conditions": [ + {"field": "meta.category", "operator": "==", "value": "A"}, + {"field": "meta.status", "operator": "==", "value": "active"}, + ], + } + ) + assert count_a_active == 2 + + def test_count_distinct_values_by_filter(self, document_store: OpenSearchDocumentStore): + docs = [ + Document(content="Doc 1", meta={"category": "A", "status": "active", "priority": 1}), + Document(content="Doc 2", meta={"category": "B", "status": "active", "priority": 2}), + Document(content="Doc 3", meta={"category": "A", "status": "inactive", "priority": 1}), + Document(content="Doc 4", meta={"category": "A", "status": "active", "priority": 3}), + Document(content="Doc 5", meta={"category": "C", "status": "active", "priority": 2}), + ] + document_store.write_documents(docs) + assert document_store.count_documents() == 5 + + # Count distinct values for all documents + distinct_counts = document_store.count_distinct_values_by_filter(filters={}) + assert distinct_counts["category"] == 3 # A, B, C + assert 
distinct_counts["status"] == 2 # active, inactive + assert distinct_counts["priority"] == 3 # 1, 2, 3 + + # Count distinct values for documents with category="A" + distinct_counts_a = document_store.count_distinct_values_by_filter( + filters={"field": "meta.category", "operator": "==", "value": "A"} + ) + assert distinct_counts_a["category"] == 1 # Only A + assert distinct_counts_a["status"] == 2 # active, inactive + assert distinct_counts_a["priority"] == 2 # 1, 3 + + # Count distinct values for documents with status="active" + distinct_counts_active = document_store.count_distinct_values_by_filter( + filters={"field": "meta.status", "operator": "==", "value": "active"} + ) + assert distinct_counts_active["category"] == 3 # A, B, C + assert distinct_counts_active["status"] == 1 # Only active + assert distinct_counts_active["priority"] == 3 # 1, 2, 3 + + # Count distinct values with complex filter (category="A" AND status="active") + distinct_counts_a_active = document_store.count_distinct_values_by_filter( + filters={ + "operator": "AND", + "conditions": [ + {"field": "meta.category", "operator": "==", "value": "A"}, + {"field": "meta.status", "operator": "==", "value": "active"}, + ], + } + ) + assert distinct_counts_a_active["category"] == 1 # Only A + assert distinct_counts_a_active["status"] == 1 # Only active + assert distinct_counts_a_active["priority"] == 2 # 1, 3 + + def test_get_fields_info(self, document_store: OpenSearchDocumentStore): + docs = [ + Document(content="Doc 1", meta={"category": "A", "status": "active", "priority": 1}), + Document(content="Doc 2", meta={"category": "B", "status": "inactive"}), + ] + document_store.write_documents(docs) + + fields_info = document_store.get_fields_info() + + # Verify that fields_info contains expected fields + assert "content" in fields_info + assert "embedding" in fields_info + assert "category" in fields_info + assert "status" in fields_info + assert "priority" in fields_info + + # Verify field types + assert fields_info["content"]["type"] == "text" + assert fields_info["embedding"]["type"] == "knn_vector" + # Metadata fields should be keyword type (from dynamic templates) + assert fields_info["category"]["type"] == "keyword" + assert fields_info["status"]["type"] == "keyword" + assert fields_info["priority"]["type"] == "long" + + def test_get_field_min_max(self, document_store: OpenSearchDocumentStore): + # Test with integer values + docs = [ + Document(content="Doc 1", meta={"priority": 1, "age": 10}), + Document(content="Doc 2", meta={"priority": 5, "age": 20}), + Document(content="Doc 3", meta={"priority": 3, "age": 15}), + Document(content="Doc 4", meta={"priority": 10, "age": 5}), + Document(content="Doc 6", meta={"rating": 10.5}), + Document(content="Doc 7", meta={"rating": 20.3}), + Document(content="Doc 8", meta={"rating": 15.7}), + Document(content="Doc 9", meta={"rating": 5.2}), + ] + document_store.write_documents(docs) + + # Test with "meta." prefix for integer field + min_max_priority = document_store.get_field_min_max("meta.priority") + assert min_max_priority["min"] == 1 + assert min_max_priority["max"] == 10 + + # Test with "meta." 
prefix for another integer field + min_max_rating = document_store.get_field_min_max("meta.age") + assert min_max_rating["min"] == 5 + assert min_max_rating["max"] == 20 + + # Test with single value + single_doc = [Document(content="Doc 5", meta={"single_value": 42})] + document_store.write_documents(single_doc) + min_max_single = document_store.get_field_min_max("meta.single_value") + assert min_max_single["min"] == 42 + assert min_max_single["max"] == 42 + + # Test with float values + min_max_score = document_store.get_field_min_max("meta.rating") + assert min_max_score["min"] == pytest.approx(5.2) + assert min_max_score["max"] == pytest.approx(20.3) + + def test_get_field_unique_values(self, document_store: OpenSearchDocumentStore): + # Test with string values + docs = [ + Document(content="Python programming", meta={"category": "A", "language": "Python"}), + Document(content="Java programming", meta={"category": "B", "language": "Java"}), + Document(content="Python scripting", meta={"category": "A", "language": "Python"}), + Document(content="JavaScript development", meta={"category": "C", "language": "JavaScript"}), + Document(content="Python data science", meta={"category": "A", "language": "Python"}), + Document(content="Java backend", meta={"category": "B", "language": "Java"}), + ] + document_store.write_documents(docs) + + # Test getting all unique values without search term + unique_values, total_count = document_store.get_field_unique_values("meta.category", None, 0, 10) + assert set(unique_values) == {"A", "B", "C"} + assert total_count == 3 + + # Test with "meta." prefix + unique_languages, lang_count = document_store.get_field_unique_values("meta.language", None, 0, 10) + assert set(unique_languages) == {"Python", "Java", "JavaScript"} + assert lang_count == 3 + + # Test pagination - first page + unique_values_page1, total_count = document_store.get_field_unique_values("meta.category", None, 0, 2) + assert len(unique_values_page1) == 2 + assert total_count == 3 + assert all(val in ["A", "B", "C"] for val in unique_values_page1) + + # Test pagination - second page + unique_values_page2, total_count = document_store.get_field_unique_values("meta.category", None, 2, 2) + assert len(unique_values_page2) == 1 + assert total_count == 3 + assert unique_values_page2[0] in ["A", "B", "C"] + + # Test with search term - filter by content matching "Python" + unique_values_filtered, total_count = document_store.get_field_unique_values("meta.category", "Python", 0, 10) + assert set(unique_values_filtered) == {"A"} # Only category A has documents with "Python" in content + assert total_count == 1 + + # Test with search term - filter by content matching "Java" + unique_values_java, total_count = document_store.get_field_unique_values("meta.category", "Java", 0, 10) + assert set(unique_values_java) == {"B"} # Only category B has documents with "Java" in content + assert total_count == 1 + + # Test with integer values + int_docs = [ + Document(content="Doc 1", meta={"priority": 1}), + Document(content="Doc 2", meta={"priority": 2}), + Document(content="Doc 3", meta={"priority": 1}), + Document(content="Doc 4", meta={"priority": 3}), + ] + document_store.write_documents(int_docs) + unique_priorities, priority_count = document_store.get_field_unique_values("meta.priority", None, 0, 10) + assert set(unique_priorities) == {"1", "2", "3"} + assert priority_count == 3 + + # Test with search term on integer field + unique_priorities_filtered, priority_count = document_store.get_field_unique_values( 
+ "meta.priority", "Doc 1", 0, 10 + ) + assert set(unique_priorities_filtered) == {"1"} + assert priority_count == 1 + + def test_query_sql(self, document_store: OpenSearchDocumentStore): + """ + Test executing SQL queries against the OpenSearch index. + """ + docs = [ + Document(content="Python programming", meta={"category": "A", "status": "active", "priority": 1}), + Document(content="Java programming", meta={"category": "B", "status": "active", "priority": 2}), + Document(content="Python scripting", meta={"category": "A", "status": "inactive", "priority": 3}), + Document(content="JavaScript development", meta={"category": "C", "status": "active", "priority": 1}), + ] + document_store.write_documents(docs, refresh=True) + + # Test SQL query with JSON format (default) + sql_query = ( + f"SELECT content, category, status, priority FROM {document_store._index} " # noqa: S608 + f"WHERE category = 'A' ORDER BY priority" + ) + result = document_store.query_sql(sql_query, response_format="json") + + # New format returns a list of dictionaries (the _source from each hit) + assert len(result) == 2 # Two documents with category A + assert isinstance(result, list) + assert all(isinstance(row, dict) for row in result) + + # Verify data contains expected values + categories = [row.get("category") for row in result] + assert all(cat == "A" for cat in categories) + + # Verify all expected fields are present + for row in result: + assert "content" in row + assert "category" in row + assert "status" in row + assert "priority" in row + + # Test SQL query with CSV format + result_csv = document_store.query_sql(sql_query, response_format="csv") + assert isinstance(result_csv, str) + assert "content" in result_csv + assert "category" in result_csv + + # Test SQL query with JDBC format + result_jdbc = document_store.query_sql(sql_query, response_format="jdbc") + # JDBC format can be dict or str depending on OpenSearch version + assert result_jdbc is not None + + # Test SQL query with RAW format + result_raw = document_store.query_sql(sql_query, response_format="raw") + assert isinstance(result_raw, str) + + # Test COUNT query + count_query = f"SELECT COUNT(*) as total FROM {document_store._index}" # noqa: S608 + count_result = document_store.query_sql(count_query, response_format="json") + # COUNT query may return different format, check it's a valid response + assert count_result is not None + + # Test error handling for invalid SQL query + invalid_query = "SELECT * FROM non_existent_index" + with pytest.raises(DocumentStoreError, match="Failed to execute SQL query"): + document_store.query_sql(invalid_query) + @pytest.mark.integration def test_write_with_routing(self, document_store: OpenSearchDocumentStore): """Test writing documents with routing metadata""" diff --git a/integrations/opensearch/tests/test_document_store_async.py b/integrations/opensearch/tests/test_document_store_async.py index bbff724b9..9c0fbe5b5 100644 --- a/integrations/opensearch/tests/test_document_store_async.py +++ b/integrations/opensearch/tests/test_document_store_async.py @@ -4,6 +4,7 @@ import pytest from haystack.dataclasses import Document +from haystack.document_stores.errors import DocumentStoreError from haystack.document_stores.types import DuplicatePolicy from haystack_integrations.document_stores.opensearch.document_store import OpenSearchDocumentStore @@ -236,6 +237,86 @@ async def test_filter_documents(self, document_store: OpenSearchDocumentStore): assert result[0].content == "2" assert result[0].meta["number"] == 
100 + @pytest.mark.asyncio + async def test_count_documents_by_filter(self, document_store: OpenSearchDocumentStore): + filterable_docs = [ + Document(content="Doc 1", meta={"category": "A", "status": "active"}), + Document(content="Doc 2", meta={"category": "B", "status": "active"}), + Document(content="Doc 3", meta={"category": "A", "status": "inactive"}), + Document(content="Doc 4", meta={"category": "A", "status": "active"}), + ] + await document_store.write_documents_async(filterable_docs) + assert await document_store.count_documents_async() == 4 + + count_a = await document_store.count_documents_by_filter_async( + filters={"field": "meta.category", "operator": "==", "value": "A"} + ) + assert count_a == 3 + + count_active = await document_store.count_documents_by_filter_async( + filters={"field": "meta.status", "operator": "==", "value": "active"} + ) + assert count_active == 3 + + count_a_active = await document_store.count_documents_by_filter_async( + filters={ + "operator": "AND", + "conditions": [ + {"field": "meta.category", "operator": "==", "value": "A"}, + {"field": "meta.status", "operator": "==", "value": "active"}, + ], + } + ) + assert count_a_active == 2 + + @pytest.mark.asyncio + async def test_count_distinct_values_by_filter(self, document_store: OpenSearchDocumentStore): + filterable_docs = [ + Document(content="Doc 1", meta={"category": "A", "status": "active", "priority": 1}), + Document(content="Doc 2", meta={"category": "B", "status": "active", "priority": 2}), + Document(content="Doc 3", meta={"category": "A", "status": "inactive", "priority": 1}), + Document(content="Doc 4", meta={"category": "A", "status": "active", "priority": 3}), + Document(content="Doc 5", meta={"category": "C", "status": "active", "priority": 2}), + ] + await document_store.write_documents_async(filterable_docs) + assert await document_store.count_documents_async() == 5 + + # count distinct values for all documents + distinct_counts = await document_store.count_distinct_values_by_filter_async(filters={}) + assert distinct_counts["category"] == 3 # A, B, C + assert distinct_counts["status"] == 2 # active, inactive + assert distinct_counts["priority"] == 3 # 1, 2, 3 + + # count distinct values for documents with category="A" + distinct_counts_a = await document_store.count_distinct_values_by_filter_async( + filters={"field": "meta.category", "operator": "==", "value": "A"} + ) + assert distinct_counts_a["category"] == 1 # Only A + assert distinct_counts_a["status"] == 2 # active, inactive + assert distinct_counts_a["priority"] == 2 # 1, 3 + + # count distinct values for documents with status="active" + distinct_counts_active = await document_store.count_distinct_values_by_filter_async( + filters={"field": "meta.status", "operator": "==", "value": "active"} + ) + assert distinct_counts_active["category"] == 3 # A, B, C + assert distinct_counts_active["status"] == 1 # Only active + assert distinct_counts_active["priority"] == 3 # 1, 2, 3 + + # count distinct values with complex filter (category="A" AND status="active") + distinct_counts_a_active = await document_store.count_distinct_values_by_filter_async( + filters={ + "operator": "AND", + "conditions": [ + {"field": "meta.category", "operator": "==", "value": "A"}, + {"field": "meta.status", "operator": "==", "value": "active"}, + ], + } + ) + assert distinct_counts_a_active["category"] == 1 # Only A + assert distinct_counts_a_active["status"] == 1 # Only active + assert distinct_counts_a_active["priority"] == 2 # 1, 3 + 
@pytest.mark.asyncio async def test_delete_documents(self, document_store: OpenSearchDocumentStore): doc = Document(content="test doc") @@ -391,3 +472,201 @@ async def test_update_by_filter_async(self, document_store: OpenSearchDocumentSt ) assert len(draft_docs) == 1 assert draft_docs[0].meta["category"] == "B" + + @pytest.mark.asyncio + async def test_get_fields_info(self, document_store: OpenSearchDocumentStore): + filterable_docs = [ + Document(content="Doc 1", meta={"category": "A", "status": "active", "priority": 1}), + Document(content="Doc 2", meta={"category": "B", "status": "inactive"}), + ] + await document_store.write_documents_async(filterable_docs) + + fields_info = await document_store.get_fields_info_async() + + # Verify that fields_info contains expected fields + assert "content" in fields_info + assert "embedding" in fields_info + assert "category" in fields_info + assert "status" in fields_info + assert "priority" in fields_info + + # Verify field types + assert fields_info["content"]["type"] == "text" + assert fields_info["embedding"]["type"] == "knn_vector" + # Metadata fields should be keyword type (from dynamic templates) + assert fields_info["category"]["type"] == "keyword" + assert fields_info["status"]["type"] == "keyword" + assert fields_info["priority"]["type"] == "long" + + @pytest.mark.asyncio + async def test_get_field_min_max(self, document_store: OpenSearchDocumentStore): + # Test with integer values + docs = [ + Document(content="Doc 1", meta={"priority": 1, "age": 10}), + Document(content="Doc 2", meta={"priority": 5, "age": 20}), + Document(content="Doc 3", meta={"priority": 3, "age": 15}), + Document(content="Doc 4", meta={"priority": 10, "age": 5}), + Document(content="Doc 6", meta={"rating": 10.5}), + Document(content="Doc 7", meta={"rating": 20.3}), + Document(content="Doc 8", meta={"rating": 15.7}), + Document(content="Doc 9", meta={"rating": 5.2}), + ] + await document_store.write_documents_async(docs) + + # Test with "meta." prefix for integer field + min_max_priority = await document_store.get_field_min_max_async("meta.priority") + assert min_max_priority["min"] == 1 + assert min_max_priority["max"] == 10 + + # Test with "meta." 
prefix for another integer field + min_max_rating = await document_store.get_field_min_max_async("meta.age") + assert min_max_rating["min"] == 5 + assert min_max_rating["max"] == 20 + + # Test with single value + single_doc = [Document(content="Doc 5", meta={"single_value": 42})] + await document_store.write_documents_async(single_doc) + min_max_single = await document_store.get_field_min_max_async("meta.single_value") + assert min_max_single["min"] == 42 + assert min_max_single["max"] == 42 + + # Test with float values + min_max_score = await document_store.get_field_min_max_async("meta.rating") + assert min_max_score["min"] == pytest.approx(5.2) + assert min_max_score["max"] == pytest.approx(20.3) + + @pytest.mark.asyncio + async def test_get_field_unique_values(self, document_store: OpenSearchDocumentStore): + # Test with string values + docs = [ + Document(content="Python programming", meta={"category": "A", "language": "Python"}), + Document(content="Java programming", meta={"category": "B", "language": "Java"}), + Document(content="Python scripting", meta={"category": "A", "language": "Python"}), + Document(content="JavaScript development", meta={"category": "C", "language": "JavaScript"}), + Document(content="Python data science", meta={"category": "A", "language": "Python"}), + Document(content="Java backend", meta={"category": "B", "language": "Java"}), + ] + await document_store.write_documents_async(docs) + + # Test getting all unique values without search term + unique_values, total_count = await document_store.get_field_unique_values_async("meta.category", None, 0, 10) + assert set(unique_values) == {"A", "B", "C"} + assert total_count == 3 + + # Test with "meta." prefix + unique_languages, lang_count = await document_store.get_field_unique_values_async("meta.language", None, 0, 10) + assert set(unique_languages) == {"Python", "Java", "JavaScript"} + assert lang_count == 3 + + # Test pagination - first page + unique_values_page1, total_count = await document_store.get_field_unique_values_async( + "meta.category", None, 0, 2 + ) + assert len(unique_values_page1) == 2 + assert total_count == 3 + assert all(val in ["A", "B", "C"] for val in unique_values_page1) + + # Test pagination - second page + unique_values_page2, total_count = await document_store.get_field_unique_values_async( + "meta.category", None, 2, 2 + ) + assert len(unique_values_page2) == 1 + assert total_count == 3 + assert unique_values_page2[0] in ["A", "B", "C"] + + # Test with search term - filter by content matching "Python" + unique_values_filtered, total_count = await document_store.get_field_unique_values_async( + "meta.category", "Python", 0, 10 + ) + assert set(unique_values_filtered) == {"A"} # Only category A has documents with "Python" in content + assert total_count == 1 + + # Test with search term - filter by content matching "Java" + unique_values_java, total_count = await document_store.get_field_unique_values_async( + "meta.category", "Java", 0, 10 + ) + assert set(unique_values_java) == {"B"} # Only category B has documents with "Java" in content + assert total_count == 1 + + # Test with integer values + int_docs = [ + Document(content="Doc 1", meta={"priority": 1}), + Document(content="Doc 2", meta={"priority": 2}), + Document(content="Doc 3", meta={"priority": 1}), + Document(content="Doc 4", meta={"priority": 3}), + ] + await document_store.write_documents_async(int_docs) + unique_priorities, priority_count = await document_store.get_field_unique_values_async( + "meta.priority", None, 0, 10 + 
) + assert set(unique_priorities) == {"1", "2", "3"} + assert priority_count == 3 + + # Test with search term on integer field + unique_priorities_filtered, priority_count = await document_store.get_field_unique_values_async( + "meta.priority", "Doc 1", 0, 10 + ) + assert set(unique_priorities_filtered) == {"1"} + assert priority_count == 1 + + @pytest.mark.asyncio + async def test_query_sql(self, document_store: OpenSearchDocumentStore): + """ + Test executing SQL queries against the OpenSearch index. + """ + docs = [ + Document(content="Python programming", meta={"category": "A", "status": "active", "priority": 1}), + Document(content="Java programming", meta={"category": "B", "status": "active", "priority": 2}), + Document(content="Python scripting", meta={"category": "A", "status": "inactive", "priority": 3}), + Document(content="JavaScript development", meta={"category": "C", "status": "active", "priority": 1}), + ] + await document_store.write_documents_async(docs, refresh=True) + + # Test SQL query with JSON format (default) + sql_query = ( + f"SELECT content, category, status, priority FROM {document_store._index} " # noqa: S608 + f"WHERE category = 'A' ORDER BY priority" + ) + result = await document_store.query_sql_async(sql_query, response_format="json") + + # New format returns a list of dictionaries (the _source from each hit) + assert len(result) == 2 # Two documents with category A + assert isinstance(result, list) + assert all(isinstance(row, dict) for row in result) + + # Verify data contains expected values + categories = [row.get("category") for row in result] + assert all(cat == "A" for cat in categories) + + # Verify all expected fields are present + for row in result: + assert "content" in row + assert "category" in row + assert "status" in row + assert "priority" in row + + # Test SQL query with CSV format + result_csv = await document_store.query_sql_async(sql_query, response_format="csv") + assert isinstance(result_csv, str) + assert "content" in result_csv + assert "category" in result_csv + + # Test SQL query with JDBC format + result_jdbc = await document_store.query_sql_async(sql_query, response_format="jdbc") + # JDBC format can be dict or str depending on OpenSearch version + assert result_jdbc is not None + + # Test SQL query with RAW format + result_raw = await document_store.query_sql_async(sql_query, response_format="raw") + assert isinstance(result_raw, str) + + # Test COUNT query + count_query = f"SELECT COUNT(*) as total FROM {document_store._index}" # noqa: S608 + count_result = await document_store.query_sql_async(count_query, response_format="json") + # COUNT query may return different format, check it's a valid response + assert count_result is not None + + # Test error handling for invalid SQL query + invalid_query = "SELECT * FROM non_existent_index" + with pytest.raises(DocumentStoreError, match="Failed to execute SQL query"): + await document_store.query_sql_async(invalid_query) diff --git a/integrations/opensearch/tests/test_embedding_retriever.py b/integrations/opensearch/tests/test_embedding_retriever.py index a01b7dc00..fe2f5865c 100644 --- a/integrations/opensearch/tests/test_embedding_retriever.py +++ b/integrations/opensearch/tests/test_embedding_retriever.py @@ -404,3 +404,51 @@ async def test_embedding_retriever_runtime_document_store_switching_async( python_query_embedding = [0.4, 0.4, 0.4] + [0.0] * 765 results_1_again = await retriever.run_async(query_embedding=python_query_embedding) assert "Python" in 
results_1_again["documents"][0].content + + +@pytest.mark.integration +def test_embedding_retriever_document_structure_with_metadata(document_store, test_documents_with_embeddings_1): + """ + Test that documents returned by embedding retriever have correct structure: + - Metadata fields are in doc.meta (not at top level) + - Special fields (content, embedding, id, score) are at top level + - All original metadata is preserved + """ + document_store.write_documents(test_documents_with_embeddings_1, refresh=True) + retriever = OpenSearchEmbeddingRetriever(document_store=document_store) + + # Query embedding to match functional programming languages + query_embedding = [0.2, 0.3, 0.4] + [0.0] * 765 + results = retriever.run(query_embedding=query_embedding, top_k=5) + + assert len(results["documents"]) > 0 + + for doc in results["documents"]: + # Verify special fields are at top level + assert hasattr(doc, "content") + assert isinstance(doc.content, str) + assert hasattr(doc, "id") + assert isinstance(doc.id, str) + assert hasattr(doc, "score") + assert doc.score is not None + assert hasattr(doc, "embedding") + assert isinstance(doc.embedding, list) + assert len(doc.embedding) == 768 + + # Verify metadata fields are in meta dict (not at top level) + assert hasattr(doc, "meta") + assert isinstance(doc.meta, dict) + + # Verify original metadata is preserved + assert "likes" in doc.meta + assert "language_type" in doc.meta + assert isinstance(doc.meta["likes"], int) + assert isinstance(doc.meta["language_type"], str) + + # Verify document can be serialized/deserialized + doc_dict = doc.to_dict() + doc_from_dict = Document.from_dict(doc_dict) + assert doc_from_dict.content == doc.content + assert doc_from_dict.meta == doc.meta + assert doc_from_dict.id == doc.id + assert doc_from_dict.embedding == doc.embedding
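
For illustration only, here is a minimal usage sketch of the metadata-introspection and SQL helpers this patch adds to OpenSearchDocumentStore. The host, index name, and sample field names below are placeholders and not part of the change; the constructor arguments follow the store's existing hosts/index parameters.

from haystack_integrations.document_stores.opensearch.document_store import OpenSearchDocumentStore

# Hypothetical connection settings; point hosts/index at your own deployment.
store = OpenSearchDocumentStore(hosts="http://localhost:9200", index="default")

# Count documents matching a Haystack metadata filter.
n_active = store.count_documents_by_filter(
    filters={"field": "meta.status", "operator": "==", "value": "active"}
)

# Per-field cardinality of metadata values across the matching documents
# (empty filters means all documents).
distinct_counts = store.count_distinct_values_by_filter(filters={})

# Index mapping, min/max of a numeric metadata field, and paginated unique values.
fields_info = store.get_fields_info()
priority_bounds = store.get_field_min_max("meta.priority")
categories, total = store.get_field_unique_values("meta.category", search_term=None, from_=0, size=10)

# Raw SQL via the OpenSearch SQL plugin; "json" returns a list of _source dicts,
# while "csv", "jdbc", and "raw" return the response as text.
rows = store.query_sql("SELECT content, category FROM default LIMIT 5", response_format="json")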