diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py index 2742e8cb2..5bc6c6793 100644 --- a/examples/experimental/tests/test_sea_async_query.py +++ b/examples/experimental/tests/test_sea_async_query.py @@ -52,12 +52,20 @@ def test_sea_async_query_with_cloud_fetch(): f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - # Execute a simple query asynchronously + # Execute a query that generates large rows to force multiple chunks + requested_row_count = 5000 cursor = connection.cursor() + query = f""" + SELECT + id, + concat('value_', repeat('a', 10000)) as test_value + FROM range(1, {requested_row_count} + 1) AS t(id) + """ + logger.info( - "Executing asynchronous query with cloud fetch: SELECT 1 as test_value" + f"Executing asynchronous query with cloud fetch to generate {requested_row_count} rows" ) - cursor.execute_async("SELECT 1 as test_value") + cursor.execute_async(query) logger.info( "Asynchronous query submitted successfully with cloud fetch enabled" ) @@ -70,8 +78,25 @@ def test_sea_async_query_with_cloud_fetch(): logger.info("Query is no longer pending, getting results...") cursor.get_async_execution_result() + + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + actual_row_count = len(results) + + logger.info( + f"Requested {requested_row_count} rows, received {actual_row_count} rows" + ) + + # Verify total row count + if actual_row_count != requested_row_count: + logger.error( + f"FAIL: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}" + ) + return False + logger.info( - "Successfully retrieved asynchronous query results with cloud fetch enabled" + "PASS: Received correct number of rows with cloud fetch and all fetch methods work correctly" ) # Close resources @@ -131,12 +156,20 @@ def test_sea_async_query_without_cloud_fetch(): f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - # Execute a simple query asynchronously + # For non-cloud fetch, use a smaller row count to avoid exceeding inline limits + requested_row_count = 100 cursor = connection.cursor() + query = f""" + SELECT + id, + concat('value_', repeat('a', 100)) as test_value + FROM range(1, {requested_row_count} + 1) AS t(id) + """ + logger.info( - "Executing asynchronous query without cloud fetch: SELECT 1 as test_value" + f"Executing asynchronous query without cloud fetch to generate {requested_row_count} rows" ) - cursor.execute_async("SELECT 1 as test_value") + cursor.execute_async(query) logger.info( "Asynchronous query submitted successfully with cloud fetch disabled" ) @@ -149,8 +182,24 @@ def test_sea_async_query_without_cloud_fetch(): logger.info("Query is no longer pending, getting results...") cursor.get_async_execution_result() + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + actual_row_count = len(results) + + logger.info( + f"Requested {requested_row_count} rows, received {actual_row_count} rows" + ) + + # Verify total row count + if actual_row_count != requested_row_count: + logger.error( + f"FAIL: Row count mismatch. 
Expected {requested_row_count}, got {actual_row_count}" + ) + return False + logger.info( - "Successfully retrieved asynchronous query results with cloud fetch disabled" + "PASS: Received correct number of rows without cloud fetch and all fetch methods work correctly" ) # Close resources diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py index 5ab6d823b..4e12d5aa4 100644 --- a/examples/experimental/tests/test_sea_sync_query.py +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -50,13 +50,34 @@ def test_sea_sync_query_with_cloud_fetch(): f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - # Execute a simple query + # Execute a query that generates large rows to force multiple chunks + requested_row_count = 10000 cursor = connection.cursor() + query = f""" + SELECT + id, + concat('value_', repeat('a', 10000)) as test_value + FROM range(1, {requested_row_count} + 1) AS t(id) + """ + + logger.info( + f"Executing synchronous query with cloud fetch to generate {requested_row_count} rows" + ) + cursor.execute(query) + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + actual_row_count = len(results) logger.info( - "Executing synchronous query with cloud fetch: SELECT 1 as test_value" + f"{actual_row_count} rows retrieved against {requested_row_count} requested" ) - cursor.execute("SELECT 1 as test_value") - logger.info("Query executed successfully with cloud fetch enabled") + + # Verify total row count + if actual_row_count != requested_row_count: + logger.error( + f"FAIL: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}" + ) + return False # Close resources cursor.close() @@ -115,13 +136,30 @@ def test_sea_sync_query_without_cloud_fetch(): f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - # Execute a simple query + # For non-cloud fetch, use a smaller row count to avoid exceeding inline limits + requested_row_count = 100 cursor = connection.cursor() logger.info( - "Executing synchronous query without cloud fetch: SELECT 1 as test_value" + f"Executing synchronous query without cloud fetch: SELECT {requested_row_count} rows" + ) + cursor.execute( + "SELECT id, 'test_value_' || CAST(id as STRING) as test_value FROM range(1, 101)" ) - cursor.execute("SELECT 1 as test_value") - logger.info("Query executed successfully with cloud fetch disabled") + + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + actual_row_count = len(results) + logger.info( + f"{actual_row_count} rows retrieved against {requested_row_count} requested" + ) + + # Verify total row count + if actual_row_count != requested_row_count: + logger.error( + f"FAIL: Row count mismatch. 
Expected {requested_row_count}, got {actual_row_count}" + ) + return False # Close resources cursor.close() diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index c0b89da75..98cb9b2a8 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -5,7 +5,12 @@ import re from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set -from databricks.sql.backend.sea.models.base import ResultManifest, StatementStatus +from databricks.sql.backend.sea.models.base import ( + ExternalLink, + ResultManifest, + StatementStatus, +) +from databricks.sql.backend.sea.models.responses import GetChunksResponse from databricks.sql.backend.sea.utils.constants import ( ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, ResultFormat, @@ -19,7 +24,7 @@ if TYPE_CHECKING: from databricks.sql.client import Cursor -from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.sea.result_set import SeaResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import ( @@ -110,6 +115,7 @@ class SeaDatabricksClient(DatabricksClient): STATEMENT_PATH = BASE_PATH + "statements" STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}" CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel" + CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}" # SEA constants POLL_INTERVAL_SECONDS = 0.2 @@ -296,7 +302,7 @@ def close_session(self, session_id: SessionId) -> None: def _extract_description_from_manifest( self, manifest: ResultManifest - ) -> Optional[List]: + ) -> List[Tuple]: """ Extract column description from a manifest object, in the format defined by the spec: https://peps.python.org/pep-0249/#description @@ -311,9 +317,6 @@ def _extract_description_from_manifest( schema_data = manifest.schema columns_data = schema_data.get("columns", []) - if not columns_data: - return None - columns = [] for col_data in columns_data: # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok) @@ -337,7 +340,7 @@ def _extract_description_from_manifest( ) ) - return columns if columns else None + return columns def _results_message_to_execute_response( self, response: Union[ExecuteStatementResponse, GetStatementResponse] @@ -358,7 +361,7 @@ def _results_message_to_execute_response( # Check for compression lz4_compressed = ( - response.manifest.result_compression == ResultCompression.LZ4_FRAME + response.manifest.result_compression == ResultCompression.LZ4_FRAME.value ) execute_response = ExecuteResponse( @@ -647,6 +650,27 @@ def get_execution_result( response = self._poll_query(command_id) return self._response_to_result_set(response, cursor) + def get_chunk_links( + self, statement_id: str, chunk_index: int + ) -> List[ExternalLink]: + """ + Get links for chunks starting from the specified index. 
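An illustrative sketch, not part of this patch: a caller could walk every chunk of a statement by following next_chunk_index, which is what the LinkFetcher added in queue.py below does in the background. client, stmt_id and handle are placeholder names.

    next_index = 0
    while next_index is not None:
        links = client.get_chunk_links(stmt_id, next_index)
        if not links:
            break
        for link in links:
            handle(link.external_link)  # signed URL for this chunk's data file
        # the highest-indexed link in the batch points at the next batch, or None when done
        next_index = max(links, key=lambda l: l.chunk_index).next_chunk_index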
+ Args: + statement_id: The statement ID + chunk_index: The starting chunk index + Returns: + ExternalLink: External link for the chunk + """ + + response_data = self._http_client._make_request( + method="GET", + path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index), + ) + response = GetChunksResponse.from_dict(response_data) + + links = response.external_links or [] + return links + # == Metadata Operations == def get_catalogs( diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index b899b791d..8450ec85d 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -26,6 +26,7 @@ ExecuteStatementResponse, GetStatementResponse, CreateSessionResponse, + GetChunksResponse, ) __all__ = [ @@ -47,4 +48,5 @@ "ExecuteStatementResponse", "GetStatementResponse", "CreateSessionResponse", + "GetChunksResponse", ] diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 75596ec9b..5a5580481 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -160,3 +160,37 @@ class CreateSessionResponse: def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse": """Create a CreateSessionResponse from a dictionary.""" return cls(session_id=data.get("session_id", "")) + + +@dataclass +class GetChunksResponse: + """ + Response from getting chunks for a statement. + + The response model can be found in the docs, here: + https://docs.databricks.com/api/workspace/statementexecution/getstatementresultchunkn + """ + + data: Optional[List[List[Any]]] = None + external_links: Optional[List[ExternalLink]] = None + byte_count: Optional[int] = None + chunk_index: Optional[int] = None + next_chunk_index: Optional[int] = None + next_chunk_internal_link: Optional[str] = None + row_count: Optional[int] = None + row_offset: Optional[int] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GetChunksResponse": + """Create a GetChunksResponse from a dictionary.""" + result = _parse_result({"result": data}) + return cls( + data=result.data, + external_links=result.external_links, + byte_count=result.byte_count, + chunk_index=result.chunk_index, + next_chunk_index=result.next_chunk_index, + next_chunk_internal_link=result.next_chunk_internal_link, + row_count=result.row_count, + row_offset=result.row_offset, + ) diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py new file mode 100644 index 000000000..130f0c5bf --- /dev/null +++ b/src/databricks/sql/backend/sea/queue.py @@ -0,0 +1,387 @@ +from __future__ import annotations + +from abc import ABC +import threading +from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING + +from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager +from databricks.sql.telemetry.models.enums import StatementType + +from databricks.sql.cloudfetch.downloader import ResultSetDownloadHandler + +try: + import pyarrow +except ImportError: + pyarrow = None + +import dateutil + +if TYPE_CHECKING: + from databricks.sql.backend.sea.backend import SeaDatabricksClient + from databricks.sql.backend.sea.models.base import ( + ExternalLink, + ResultData, + ResultManifest, + ) +from databricks.sql.backend.sea.utils.constants import ResultFormat +from databricks.sql.exc import ProgrammingError, ServerOperationError +from 
databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink +from databricks.sql.types import SSLOptions +from databricks.sql.utils import ( + ArrowQueue, + CloudFetchQueue, + ResultSetQueue, + create_arrow_table_from_arrow_file, +) + +import logging + +logger = logging.getLogger(__name__) + + +class SeaResultSetQueueFactory(ABC): + @staticmethod + def build_queue( + result_data: ResultData, + manifest: ResultManifest, + statement_id: str, + ssl_options: SSLOptions, + description: List[Tuple], + max_download_threads: int, + sea_client: SeaDatabricksClient, + lz4_compressed: bool, + ) -> ResultSetQueue: + """ + Factory method to build a result set queue for SEA backend. + + Args: + result_data (ResultData): Result data from SEA response + manifest (ResultManifest): Manifest from SEA response + statement_id (str): Statement ID for the query + description (List[List[Any]]): Column descriptions + max_download_threads (int): Maximum number of download threads + sea_client (SeaDatabricksClient): SEA client for fetching additional links + lz4_compressed (bool): Whether the data is LZ4 compressed + + Returns: + ResultSetQueue: The appropriate queue for the result data + """ + + if manifest.format == ResultFormat.JSON_ARRAY.value: + # INLINE disposition with JSON_ARRAY format + return JsonQueue(result_data.data) + elif manifest.format == ResultFormat.ARROW_STREAM.value: + if result_data.attachment is not None: + # direct results from Hybrid disposition + arrow_file = ( + ResultSetDownloadHandler._decompress_data(result_data.attachment) + if lz4_compressed + else result_data.attachment + ) + arrow_table = create_arrow_table_from_arrow_file( + arrow_file, description + ) + logger.debug(f"Created arrow table with {arrow_table.num_rows} rows") + return ArrowQueue(arrow_table, manifest.total_row_count) + + # EXTERNAL_LINKS disposition + return SeaCloudFetchQueue( + result_data=result_data, + max_download_threads=max_download_threads, + ssl_options=ssl_options, + sea_client=sea_client, + statement_id=statement_id, + total_chunk_count=manifest.total_chunk_count, + lz4_compressed=lz4_compressed, + description=description, + ) + raise ProgrammingError("Invalid result format") + + +class JsonQueue(ResultSetQueue): + """Queue implementation for JSON_ARRAY format data.""" + + def __init__(self, data_array: Optional[List[List[str]]]): + """Initialize with JSON array data.""" + self.data_array = data_array or [] + self.cur_row_index = 0 + self.num_rows = len(self.data_array) + + def next_n_rows(self, num_rows: int) -> List[List[str]]: + """Get the next n rows from the data array.""" + length = min(num_rows, self.num_rows - self.cur_row_index) + slice = self.data_array[self.cur_row_index : self.cur_row_index + length] + self.cur_row_index += length + return slice + + def remaining_rows(self) -> List[List[str]]: + """Get all remaining rows from the data array.""" + slice = self.data_array[self.cur_row_index :] + self.cur_row_index += len(slice) + return slice + + def close(self): + return + + +class LinkFetcher: + """ + Background helper that incrementally retrieves *external links* for a + result set produced by the SEA backend and feeds them to a + :class:`databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager`. + + The SEA backend splits large result sets into *chunks*. Each chunk is + stored remotely (e.g., in object storage) and exposed via a signed URL + encapsulated by an :class:`ExternalLink`. Only the first batch of links is + returned with the initial query response. 
The remaining links must be + pulled on demand using the *next-chunk* token embedded in each + :pyattr:`ExternalLink.next_chunk_index`. + + LinkFetcher takes care of this choreography so callers (primarily + ``SeaCloudFetchQueue``) can simply ask for the link of a specific + ``chunk_index`` and block until it becomes available. + + Key responsibilities: + + • Maintain an in-memory mapping from ``chunk_index`` → ``ExternalLink``. + • Launch a background worker thread that continuously requests the next + batch of links from the backend until all chunks have been discovered or + an unrecoverable error occurs. + • Bridge SEA link objects to the Thrift representation expected by the + existing download manager. + • Provide a synchronous API (`get_chunk_link`) that blocks until the desired + link is present in the cache. + """ + + def __init__( + self, + download_manager: ResultFileDownloadManager, + backend: SeaDatabricksClient, + statement_id: str, + initial_links: List[ExternalLink], + total_chunk_count: int, + ): + self.download_manager = download_manager + self.backend = backend + self._statement_id = statement_id + + self._shutdown_event = threading.Event() + + self._link_data_update = threading.Condition() + self._error: Optional[Exception] = None + self.chunk_index_to_link: Dict[int, ExternalLink] = {} + + self._add_links(initial_links) + self.total_chunk_count = total_chunk_count + + # DEBUG: capture initial state for observability + logger.debug( + "LinkFetcher[%s]: initialized with %d initial link(s); expecting %d total chunk(s)", + statement_id, + len(initial_links), + total_chunk_count, + ) + + def _add_links(self, links: List[ExternalLink]): + """Cache *links* locally and enqueue them with the download manager.""" + logger.debug( + "LinkFetcher[%s]: caching %d link(s) – chunks %s", + self._statement_id, + len(links), + ", ".join(str(l.chunk_index) for l in links) if links else "", + ) + for link in links: + self.chunk_index_to_link[link.chunk_index] = link + self.download_manager.add_link(LinkFetcher._convert_to_thrift_link(link)) + + def _get_next_chunk_index(self) -> Optional[int]: + """Return the next *chunk_index* that should be requested from the backend, or ``None`` if we have them all.""" + with self._link_data_update: + max_chunk_index = max(self.chunk_index_to_link.keys(), default=None) + if max_chunk_index is None: + return 0 + max_link = self.chunk_index_to_link[max_chunk_index] + return max_link.next_chunk_index + + def _trigger_next_batch_download(self) -> bool: + """Fetch the next batch of links from the backend and return *True* on success.""" + logger.debug( + "LinkFetcher[%s]: requesting next batch of links", self._statement_id + ) + next_chunk_index = self._get_next_chunk_index() + if next_chunk_index is None: + return False + + try: + links = self.backend.get_chunk_links(self._statement_id, next_chunk_index) + with self._link_data_update: + self._add_links(links) + self._link_data_update.notify_all() + except Exception as e: + logger.error( + f"LinkFetcher: Error fetching links for chunk {next_chunk_index}: {e}" + ) + with self._link_data_update: + self._error = e + self._link_data_update.notify_all() + return False + + logger.debug( + "LinkFetcher[%s]: received %d new link(s)", + self._statement_id, + len(links), + ) + return True + + def get_chunk_link(self, chunk_index: int) -> Optional[ExternalLink]: + """Return (blocking) the :class:`ExternalLink` associated with *chunk_index*.""" + logger.debug( + "LinkFetcher[%s]: waiting for link of chunk %d", + 
self._statement_id, + chunk_index, + ) + if chunk_index >= self.total_chunk_count: + return None + + with self._link_data_update: + while chunk_index not in self.chunk_index_to_link: + if self._error: + raise self._error + if self._shutdown_event.is_set(): + raise ProgrammingError( + "LinkFetcher is shutting down without providing link for chunk index {}".format( + chunk_index + ) + ) + self._link_data_update.wait() + + return self.chunk_index_to_link[chunk_index] + + @staticmethod + def _convert_to_thrift_link(link: ExternalLink) -> TSparkArrowResultLink: + """Convert SEA external links to Thrift format for compatibility with existing download manager.""" + # Parse the ISO format expiration time + expiry_time = int(dateutil.parser.parse(link.expiration).timestamp()) + return TSparkArrowResultLink( + fileLink=link.external_link, + expiryTime=expiry_time, + rowCount=link.row_count, + bytesNum=link.byte_count, + startRowOffset=link.row_offset, + httpHeaders=link.http_headers or {}, + ) + + def _worker_loop(self): + """Entry point for the background thread.""" + logger.debug("LinkFetcher[%s]: worker thread started", self._statement_id) + while not self._shutdown_event.is_set(): + links_downloaded = self._trigger_next_batch_download() + if not links_downloaded: + self._shutdown_event.set() + logger.debug("LinkFetcher[%s]: worker thread exiting", self._statement_id) + with self._link_data_update: + self._link_data_update.notify_all() + + def start(self): + """Spawn the worker thread.""" + logger.debug("LinkFetcher[%s]: starting worker thread", self._statement_id) + self._worker_thread = threading.Thread( + target=self._worker_loop, name=f"LinkFetcher-{self._statement_id}" + ) + self._worker_thread.start() + + def stop(self): + """Signal the worker thread to stop and wait for its termination.""" + logger.debug("LinkFetcher[%s]: stopping worker thread", self._statement_id) + self._shutdown_event.set() + self._worker_thread.join() + logger.debug("LinkFetcher[%s]: worker thread stopped", self._statement_id) + + +class SeaCloudFetchQueue(CloudFetchQueue): + """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for SEA backend.""" + + def __init__( + self, + result_data: ResultData, + max_download_threads: int, + ssl_options: SSLOptions, + sea_client: SeaDatabricksClient, + statement_id: str, + total_chunk_count: int, + lz4_compressed: bool = False, + description: List[Tuple] = [], + ): + """ + Initialize the SEA CloudFetchQueue. 
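A minimal lifecycle sketch of how this queue and LinkFetcher cooperate; every name not defined in this module (download_manager, sea_client, stmt_id, initial_links) is a placeholder.

    fetcher = LinkFetcher(download_manager, sea_client, stmt_id,
                          initial_links, total_chunk_count=12)
    fetcher.start()                      # background thread keeps pulling link batches

    link = fetcher.get_chunk_link(0)     # typically served from the initial batch
    link = fetcher.get_chunk_link(7)     # blocks until the worker has cached chunk 7

    fetcher.stop()                       # set the shutdown event and join the worker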
+ + Args: + initial_links: Initial list of external links to download + schema_bytes: Arrow schema bytes + max_download_threads: Maximum number of download threads + ssl_options: SSL options for downloads + sea_client: SEA client for fetching additional links + statement_id: Statement ID for the query + total_chunk_count: Total number of chunks in the result set + lz4_compressed: Whether the data is LZ4 compressed + description: Column descriptions + """ + + super().__init__( + max_download_threads=max_download_threads, + ssl_options=ssl_options, + statement_id=statement_id, + schema_bytes=None, + lz4_compressed=lz4_compressed, + description=description, + # TODO: fix these arguments when telemetry is implemented in SEA + session_id_hex=None, + chunk_id=0, + ) + + logger.debug( + "SeaCloudFetchQueue: Initialize CloudFetch loader for statement {}, total chunks: {}".format( + statement_id, total_chunk_count + ) + ) + + initial_links = result_data.external_links or [] + + # Track the current chunk we're processing + self._current_chunk_index = 0 + + self.link_fetcher = None # for empty responses, we do not need a link fetcher + if total_chunk_count > 0: + self.link_fetcher = LinkFetcher( + download_manager=self.download_manager, + backend=sea_client, + statement_id=statement_id, + initial_links=initial_links, + total_chunk_count=total_chunk_count, + ) + self.link_fetcher.start() + + # Initialize table and position + self.table = self._create_next_table() + + def _create_next_table(self) -> Union["pyarrow.Table", None]: + """Create next table by retrieving the logical next downloaded file.""" + if self.link_fetcher is None: + return None + + chunk_link = self.link_fetcher.get_chunk_link(self._current_chunk_index) + if chunk_link is None: + return None + + row_offset = chunk_link.row_offset + # NOTE: link has already been submitted to download manager at this point + arrow_table = self._create_table_at_offset(row_offset) + + self._current_chunk_index += 1 + + return arrow_table + + def close(self): + super().close() + if self.link_fetcher: + self.link_fetcher.stop() diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py new file mode 100644 index 000000000..a6a0a298b --- /dev/null +++ b/src/databricks/sql/backend/sea/result_set.py @@ -0,0 +1,266 @@ +from __future__ import annotations + +from typing import Any, List, Optional, TYPE_CHECKING + +import logging + +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest +from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter + +try: + import pyarrow +except ImportError: + pyarrow = None + +if TYPE_CHECKING: + from databricks.sql.client import Connection + from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.types import Row +from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory +from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.result_set import ResultSet + +logger = logging.getLogger(__name__) + + +class SeaResultSet(ResultSet): + """ResultSet implementation for SEA backend.""" + + def __init__( + self, + connection: Connection, + execute_response: ExecuteResponse, + sea_client: SeaDatabricksClient, + result_data: ResultData, + manifest: ResultManifest, + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + ): + """ + Initialize a SeaResultSet with the response from a SEA query execution. 
+ + Args: + connection: The parent connection + execute_response: Response from the execute command + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + result_data: Result data from SEA response + manifest: Manifest from SEA response + """ + + self.manifest = manifest + + statement_id = execute_response.command_id.to_sea_statement_id() + if statement_id is None: + raise ValueError("Command ID is not a SEA statement ID") + + results_queue = SeaResultSetQueueFactory.build_queue( + result_data, + self.manifest, + statement_id, + ssl_options=connection.session.ssl_options, + description=execute_response.description, + max_download_threads=sea_client.max_download_threads, + sea_client=sea_client, + lz4_compressed=execute_response.lz4_compressed, + ) + + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + results_queue=results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, + lz4_compressed=execute_response.lz4_compressed, + arrow_schema_bytes=execute_response.arrow_schema_bytes, + ) + + def _convert_json_types(self, row: List[str]) -> List[Any]: + """ + Convert string values in the row to appropriate Python types based on column metadata. + """ + + # JSON + INLINE gives us string values, so we convert them to appropriate + # types based on column metadata + converted_row = [] + + for i, value in enumerate(row): + column_type = self.description[i][1] + precision = self.description[i][4] + scale = self.description[i][5] + + try: + converted_value = SqlTypeConverter.convert_value( + value, column_type, precision=precision, scale=scale + ) + converted_row.append(converted_value) + except Exception as e: + logger.warning( + f"Error converting value '{value}' to {column_type}: {e}" + ) + converted_row.append(value) + + return converted_row + + def _convert_json_to_arrow_table(self, rows: List[List[str]]) -> "pyarrow.Table": + """ + Convert raw data rows to Arrow table. + + Args: + rows: List of raw data rows + + Returns: + PyArrow Table containing the converted values + """ + + if not rows: + return pyarrow.Table.from_pydict({}) + + # create a generator for row conversion + converted_rows_iter = (self._convert_json_types(row) for row in rows) + cols = list(map(list, zip(*converted_rows_iter))) + + names = [col[0] for col in self.description] + return pyarrow.Table.from_arrays(cols, names=names) + + def _create_json_table(self, rows: List[List[str]]) -> List[Row]: + """ + Convert raw data rows to Row objects with named columns based on description. + + Args: + rows: List of raw data rows + Returns: + List of Row objects with named columns and converted values + """ + + ResultRow = Row(*[col[0] for col in self.description]) + return [ResultRow(*self._convert_json_types(row)) for row in rows] + + def fetchmany_json(self, size: int) -> List[List[str]]: + """ + Fetch the next set of rows as a columnar table. 
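For orientation, a small made-up example of what the JSON path sees: each cursor description entry is a PEP 249 7-tuple, and _convert_json_types reads the type code, precision and scale from positions 1, 4 and 5 before converting the string values.

    description = [
        ("id", "int", None, None, None, None, None),
        ("price", "decimal", None, None, 10, 2, None),
    ]
    raw_row = ["42", "19.5"]        # JSON + INLINE rows arrive as strings
    # with this description, the converted row is [42, Decimal('19.50')]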
+ + Args: + size: Number of rows to fetch + + Returns: + Columnar table containing the fetched rows + + Raises: + ValueError: If size is negative + """ + + if size < 0: + raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") + + results = self.results.next_n_rows(size) + self._next_row_index += len(results) + + return results + + def fetchall_json(self) -> List[List[str]]: + """ + Fetch all remaining rows as a columnar table. + + Returns: + Columnar table containing all remaining rows + """ + + results = self.results.remaining_rows() + self._next_row_index += len(results) + + return results + + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + """ + Fetch the next set of rows as an Arrow table. + + Args: + size: Number of rows to fetch + + Returns: + PyArrow Table containing the fetched rows + + Raises: + ImportError: If PyArrow is not installed + ValueError: If size is negative + """ + + if size < 0: + raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") + + results = self.results.next_n_rows(size) + if isinstance(self.results, JsonQueue): + results = self._convert_json_to_arrow_table(results) + + self._next_row_index += results.num_rows + + return results + + def fetchall_arrow(self) -> "pyarrow.Table": + """ + Fetch all remaining rows as an Arrow table. + """ + + results = self.results.remaining_rows() + if isinstance(self.results, JsonQueue): + results = self._convert_json_to_arrow_table(results) + + self._next_row_index += results.num_rows + + return results + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + + Returns: + A single Row object or None if no more rows are available + """ + + if isinstance(self.results, JsonQueue): + res = self._create_json_table(self.fetchmany_json(1)) + else: + res = self._convert_arrow_table(self.fetchmany_arrow(1)) + + return res[0] if res else None + + def fetchmany(self, size: int) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. + + Args: + size: Number of rows to fetch (defaults to arraysize if None) + + Returns: + List of Row objects + + Raises: + ValueError: If size is negative + """ + + if isinstance(self.results, JsonQueue): + return self._create_json_table(self.fetchmany_json(size)) + else: + return self._convert_arrow_table(self.fetchmany_arrow(size)) + + def fetchall(self) -> List[Row]: + """ + Fetch all remaining rows of a query result, returning them as a list of rows. + + Returns: + List of Row objects containing all remaining rows + """ + + if isinstance(self.results, JsonQueue): + return self._create_json_table(self.fetchall_json()) + else: + return self._convert_arrow_table(self.fetchall_arrow()) diff --git a/src/databricks/sql/backend/sea/utils/conversion.py b/src/databricks/sql/backend/sea/utils/conversion.py new file mode 100644 index 000000000..b2de97f5d --- /dev/null +++ b/src/databricks/sql/backend/sea/utils/conversion.py @@ -0,0 +1,160 @@ +""" +Type conversion utilities for the Databricks SQL Connector. + +This module provides functionality to convert string values from SEA Inline results +to appropriate Python types based on column metadata. 
+""" + +import datetime +import decimal +import logging +from dateutil import parser +from typing import Callable, Dict, Optional + +logger = logging.getLogger(__name__) + + +def _convert_decimal( + value: str, precision: Optional[int] = None, scale: Optional[int] = None +) -> decimal.Decimal: + """ + Convert a string value to a decimal with optional precision and scale. + + Args: + value: The string value to convert + precision: Optional precision (total number of significant digits) for the decimal + scale: Optional scale (number of decimal places) for the decimal + + Returns: + A decimal.Decimal object with appropriate precision and scale + """ + + # First create the decimal from the string value + result = decimal.Decimal(value) + + # Apply scale (quantize to specific number of decimal places) if specified + quantizer = None + if scale is not None: + quantizer = decimal.Decimal(f'0.{"0" * scale}') + + # Apply precision (total number of significant digits) if specified + context = None + if precision is not None: + context = decimal.Context(prec=precision) + + if quantizer is not None: + result = result.quantize(quantizer, context=context) + + return result + + +class SqlType: + """ + SQL type constants + + The list of types can be found in the SEA REST API Reference: + https://docs.databricks.com/api/workspace/statementexecution/executestatement + """ + + # Numeric types + BYTE = "byte" + SHORT = "short" + INT = "int" + LONG = "long" + FLOAT = "float" + DOUBLE = "double" + DECIMAL = "decimal" + + # Boolean type + BOOLEAN = "boolean" + + # Date/Time types + DATE = "date" + TIMESTAMP = "timestamp" + INTERVAL = "interval" + + # String types + CHAR = "char" + STRING = "string" + + # Binary type + BINARY = "binary" + + # Complex types + ARRAY = "array" + MAP = "map" + STRUCT = "struct" + + # Other types + NULL = "null" + USER_DEFINED_TYPE = "user_defined_type" + + +class SqlTypeConverter: + """ + Utility class for converting SQL types to Python types. + Based on the types supported by the Databricks SDK. + """ + + # SQL type to conversion function mapping + # TODO: complex types + TYPE_MAPPING: Dict[str, Callable] = { + # Numeric types + SqlType.BYTE: lambda v: int(v), + SqlType.SHORT: lambda v: int(v), + SqlType.INT: lambda v: int(v), + SqlType.LONG: lambda v: int(v), + SqlType.FLOAT: lambda v: float(v), + SqlType.DOUBLE: lambda v: float(v), + SqlType.DECIMAL: _convert_decimal, + # Boolean type + SqlType.BOOLEAN: lambda v: v.lower() in ("true", "t", "1", "yes", "y"), + # Date/Time types + SqlType.DATE: lambda v: datetime.date.fromisoformat(v), + SqlType.TIMESTAMP: lambda v: parser.parse(v), + SqlType.INTERVAL: lambda v: v, # Keep as string for now + # String types - no conversion needed + SqlType.CHAR: lambda v: v, + SqlType.STRING: lambda v: v, + # Binary type + SqlType.BINARY: lambda v: bytes.fromhex(v), + # Other types + SqlType.NULL: lambda v: None, + # Complex types and user-defined types return as-is + SqlType.USER_DEFINED_TYPE: lambda v: v, + } + + @staticmethod + def convert_value( + value: str, + sql_type: str, + **kwargs, + ) -> object: + """ + Convert a string value to the appropriate Python type based on SQL type. 
+ + Args: + value: The string value to convert + sql_type: The SQL type (e.g., 'int', 'decimal') + **kwargs: Additional keyword arguments for the conversion function + + Returns: + The converted value in the appropriate Python type + """ + + sql_type = sql_type.lower().strip() + + if sql_type not in SqlTypeConverter.TYPE_MAPPING: + return value + + converter_func = SqlTypeConverter.TYPE_MAPPING[sql_type] + try: + if sql_type == SqlType.DECIMAL: + precision = kwargs.get("precision", None) + scale = kwargs.get("scale", None) + return converter_func(value, precision, scale) + else: + return converter_func(value) + except (ValueError, TypeError, decimal.InvalidOperation) as e: + logger.warning(f"Error converting value '{value}' to {sql_type}: {e}") + return value diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index 43db35984..0bdb23b03 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -17,7 +17,7 @@ ) if TYPE_CHECKING: - from databricks.sql.result_set import SeaResultSet + from databricks.sql.backend.sea.result_set import SeaResultSet from databricks.sql.backend.types import ExecuteResponse @@ -70,16 +70,20 @@ def _filter_sea_result_set( result_data = ResultData(data=filtered_rows, external_links=None) from databricks.sql.backend.sea.backend import SeaDatabricksClient - from databricks.sql.result_set import SeaResultSet + from databricks.sql.backend.sea.result_set import SeaResultSet # Create a new SeaResultSet with the filtered data + manifest = result_set.manifest + manifest.total_row_count = len(filtered_rows) + filtered_result_set = SeaResultSet( connection=result_set.connection, execute_response=execute_response, sea_client=cast(SeaDatabricksClient, result_set.backend), + result_data=result_data, + manifest=manifest, buffer_size_bytes=result_set.buffer_size_bytes, arraysize=result_set.arraysize, - result_data=result_data, ) return filtered_result_set diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 16a664e78..b404b1669 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -6,8 +6,10 @@ import time import threading from typing import List, Optional, Union, Any, TYPE_CHECKING +from uuid import UUID from databricks.sql.result_set import ThriftResultSet +from databricks.sql.telemetry.models.event import StatementType if TYPE_CHECKING: @@ -43,11 +45,10 @@ ) from databricks.sql.utils import ( - ResultSetQueueFactory, + ThriftResultSetQueueFactory, _bound, RequestErrorInfo, NoRetryReason, - ResultSetQueueFactory, convert_arrow_based_set_to_arrow_table, convert_decimals_in_arrow_table, convert_column_based_set_to_arrow_table, @@ -166,6 +167,7 @@ def __init__( self._use_arrow_native_complex_types = kwargs.get( "_use_arrow_native_complex_types", True ) + self._use_arrow_native_decimals = kwargs.get("_use_arrow_native_decimals", True) self._use_arrow_native_timestamps = kwargs.get( "_use_arrow_native_timestamps", True @@ -788,7 +790,7 @@ def _results_message_to_execute_response(self, resp, operation_state): direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation - is_direct_results = ( + has_more_rows = ( (not direct_results) or (not direct_results.resultSet) or direct_results.resultSet.hasMoreRows @@ -829,7 +831,7 @@ def _results_message_to_execute_response(self, resp, operation_state): 
result_format=t_result_set_metadata_resp.resultFormat, ) - return execute_response, is_direct_results + return execute_response, has_more_rows def get_execution_result( self, command_id: CommandId, cursor: Cursor @@ -874,7 +876,7 @@ def get_execution_result( lz4_compressed = t_result_set_metadata_resp.lz4Compressed is_staging_operation = t_result_set_metadata_resp.isStagingOperation - is_direct_results = resp.hasMoreRows + has_more_rows = resp.hasMoreRows status = CommandState.from_thrift_state(resp.status) or CommandState.RUNNING @@ -899,7 +901,7 @@ def get_execution_result( t_row_set=resp.results, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + has_more_rows=has_more_rows, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -1018,7 +1020,7 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response, is_direct_results = self._handle_execute_response( + execute_response, has_more_rows = self._handle_execute_response( resp, cursor ) @@ -1036,7 +1038,7 @@ def execute_command( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + has_more_rows=has_more_rows, ) def get_catalogs( @@ -1058,9 +1060,7 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response, is_direct_results = self._handle_execute_response( - resp, cursor - ) + execute_response, has_more_rows = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1076,7 +1076,7 @@ def get_catalogs( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + has_more_rows=has_more_rows, ) def get_schemas( @@ -1104,9 +1104,7 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response, is_direct_results = self._handle_execute_response( - resp, cursor - ) + execute_response, has_more_rows = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1122,7 +1120,7 @@ def get_schemas( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + has_more_rows=has_more_rows, ) def get_tables( @@ -1154,9 +1152,7 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response, is_direct_results = self._handle_execute_response( - resp, cursor - ) + execute_response, has_more_rows = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1172,7 +1168,7 @@ def get_tables( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + has_more_rows=has_more_rows, ) def get_columns( @@ -1204,9 +1200,7 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response, is_direct_results = self._handle_execute_response( - resp, cursor - ) + execute_response, has_more_rows = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1222,7 +1216,7 @@ def get_columns( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + 
has_more_rows=has_more_rows, ) def _handle_execute_response(self, resp, cursor): @@ -1257,6 +1251,7 @@ def fetch_results( lz4_compressed: bool, arrow_schema_bytes, description, + chunk_id: int, use_cloud_fetch=True, ): thrift_handle = command_id.to_thrift_handle() @@ -1286,7 +1281,7 @@ def fetch_results( session_id_hex=self._session_id_hex, ) - queue = ResultSetQueueFactory.build_queue( + queue = ThriftResultSetQueueFactory.build_queue( row_set_type=resp.resultSetMetadata.resultFormat, t_row_set=resp.results, arrow_schema_bytes=arrow_schema_bytes, @@ -1294,9 +1289,16 @@ def fetch_results( lz4_compressed=lz4_compressed, description=description, ssl_options=self._ssl_options, + session_id_hex=self._session_id_hex, + statement_id=command_id.to_hex_guid(), + chunk_id=chunk_id, ) - return queue, resp.hasMoreRows + return ( + queue, + resp.hasMoreRows, + len(resp.results.resultLinks) if resp.results.resultLinks else 0, + ) def cancel_command(self, command_id: CommandId) -> None: thrift_handle = command_id.to_thrift_handle() diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index f645fc6d1..5708f5e54 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -4,6 +4,7 @@ import logging from databricks.sql.backend.utils.guid_utils import guid_to_hex_id +from databricks.sql.telemetry.models.enums import StatementType from databricks.sql.thrift_api.TCLIService import ttypes logger = logging.getLogger(__name__) @@ -418,7 +419,7 @@ class ExecuteResponse: command_id: CommandId status: CommandState - description: Optional[List[Tuple]] = None + description: List[Tuple] has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 873c55a88..de53a86e9 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -298,7 +298,9 @@ def read(self) -> Optional[OAuthToken]: driver_connection_params = DriverConnectionParameters( http_path=http_path, - mode=DatabricksClientType.THRIFT, + mode=DatabricksClientType.SEA + if self.session.use_sea + else DatabricksClientType.THRIFT, host_info=HostDetails(host_url=server_hostname, port=self.session.port), auth_mech=TelemetryHelper.get_auth_mechanism(self.session.auth_provider), auth_flow=TelemetryHelper.get_auth_flow(self.session.auth_provider), diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 7e96cd323..32b698bed 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -1,7 +1,7 @@ import logging from concurrent.futures import ThreadPoolExecutor, Future -from typing import List, Union +from typing import List, Union, Tuple, Optional from databricks.sql.cloudfetch.downloader import ( ResultSetDownloadHandler, @@ -9,7 +9,7 @@ DownloadedFile, ) from databricks.sql.types import SSLOptions - +from databricks.sql.telemetry.models.event import StatementType from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink logger = logging.getLogger(__name__) @@ -22,17 +22,22 @@ def __init__( max_download_threads: int, lz4_compressed: bool, ssl_options: SSLOptions, + session_id_hex: Optional[str], + statement_id: str, + chunk_id: int, ): - self._pending_links: List[TSparkArrowResultLink] = [] - for link in links: + self._pending_links: List[Tuple[int, TSparkArrowResultLink]] = [] + self.chunk_id = chunk_id + for i, link 
in enumerate(links, start=chunk_id): if link.rowCount <= 0: continue logger.debug( - "ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format( - link.startRowOffset, link.rowCount + "ResultFileDownloadManager: adding file link, chunk id {}, start offset {}, row count: {}".format( + i, link.startRowOffset, link.rowCount ) ) - self._pending_links.append(link) + self._pending_links.append((i, link)) + self.chunk_id += len(links) self._download_tasks: List[Future[DownloadedFile]] = [] self._max_download_threads: int = max_download_threads @@ -40,6 +45,8 @@ def __init__( self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed) self._ssl_options = ssl_options + self.session_id_hex = session_id_hex + self.statement_id = statement_id def get_next_downloaded_file( self, next_row_offset: int @@ -89,18 +96,42 @@ def _schedule_downloads(self): while (len(self._download_tasks) < self._max_download_threads) and ( len(self._pending_links) > 0 ): - link = self._pending_links.pop(0) + chunk_id, link = self._pending_links.pop(0) logger.debug( - "- start: {}, row count: {}".format(link.startRowOffset, link.rowCount) + "- chunk: {}, start: {}, row count: {}".format( + chunk_id, link.startRowOffset, link.rowCount + ) ) handler = ResultSetDownloadHandler( settings=self._downloadable_result_settings, link=link, ssl_options=self._ssl_options, + chunk_id=chunk_id, + session_id_hex=self.session_id_hex, + statement_id=self.statement_id, ) task = self._thread_pool.submit(handler.run) self._download_tasks.append(task) + def add_link(self, link: TSparkArrowResultLink): + """ + Add more links to the download manager. + + Args: + link: Link to add + """ + + if link.rowCount <= 0: + return + + logger.debug( + "ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format( + link.startRowOffset, link.rowCount + ) + ) + self._pending_links.append((self.chunk_id, link)) + self.chunk_id += 1 + def _shutdown_manager(self): # Clear download handlers and shutdown the thread pool self._pending_links = [] diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 4421c4770..57047d6ff 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -1,5 +1,6 @@ import logging from dataclasses import dataclass +from typing import Optional from requests.adapters import Retry import lz4.frame @@ -8,6 +9,8 @@ from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink from databricks.sql.exc import Error from databricks.sql.types import SSLOptions +from databricks.sql.telemetry.latency_logger import log_latency +from databricks.sql.telemetry.models.event import StatementType logger = logging.getLogger(__name__) @@ -65,12 +68,19 @@ def __init__( settings: DownloadableResultSettings, link: TSparkArrowResultLink, ssl_options: SSLOptions, + chunk_id: int, + session_id_hex: Optional[str], + statement_id: str, ): self.settings = settings self.link = link self._ssl_options = ssl_options self._http_client = DatabricksHttpClient.get_instance() + self.chunk_id = chunk_id + self.session_id_hex = session_id_hex + self.statement_id = statement_id + @log_latency(StatementType.QUERY) def run(self) -> DownloadedFile: """ Download the file described in the cloud fetch link. 
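A hypothetical sketch of the download manager's new bookkeeping (made-up values; manager and link_a through link_d are placeholders): pending links are now stored together with a chunk id, and links appended later through add_link() continue the numbering so each download can be attributed to a chunk.

    # after construction with three links and chunk_id=0:
    manager._pending_links        # [(0, link_a), (1, link_b), (2, link_c)]
    manager.add_link(link_d)      # stored as (3, link_d); manager.chunk_id becomes 4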
@@ -80,8 +90,8 @@ def run(self) -> DownloadedFile: """ logger.debug( - "ResultSetDownloadHandler: starting file download, offset {}, row count {}".format( - self.link.startRowOffset, self.link.rowCount + "ResultSetDownloadHandler: starting file download, chunk id {}, offset {}, row count {}".format( + self.chunk_id, self.link.startRowOffset, self.link.rowCount ) ) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 9627c5977..3d3587cae 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -1,12 +1,11 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import List, Optional, Any, TYPE_CHECKING +from typing import List, Optional, TYPE_CHECKING, Tuple import logging import pandas - try: import pyarrow except ImportError: @@ -14,14 +13,16 @@ if TYPE_CHECKING: from databricks.sql.backend.thrift_backend import ThriftDatabricksClient - from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.client import Connection - from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.types import Row from databricks.sql.exc import RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ColumnTable, ColumnQueue +from databricks.sql.utils import ( + ColumnTable, + ColumnQueue, +) from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse +from databricks.sql.telemetry.models.event import StatementType logger = logging.getLogger(__name__) @@ -42,9 +43,9 @@ def __init__( command_id: CommandId, status: CommandState, has_been_closed_server_side: bool = False, - is_direct_results: bool = False, + has_more_rows: bool = False, results_queue=None, - description=None, + description: List[Tuple] = [], is_staging_operation: bool = False, lz4_compressed: bool = False, arrow_schema_bytes: Optional[bytes] = None, @@ -60,7 +61,7 @@ def __init__( :param command_id: The command ID :param status: The command status :param has_been_closed_server_side: Whether the command has been closed on the server - :param is_direct_results: Whether the command has more rows + :param has_more_rows: Whether the command has more rows :param results_queue: The results queue :param description: column description of the results :param is_staging_operation: Whether the command is a staging operation @@ -75,7 +76,7 @@ def __init__( self.command_id = command_id self.status = status self.has_been_closed_server_side = has_been_closed_server_side - self.is_direct_results = is_direct_results + self.has_more_rows = has_more_rows self.results = results_queue self._is_staging_operation = is_staging_operation self.lz4_compressed = lz4_compressed @@ -89,6 +90,44 @@ def __iter__(self): else: break + def _convert_arrow_table(self, table): + column_names = [c[0] for c in self.description] + ResultRow = Row(*column_names) + + if self.connection.disable_pandas is True: + return [ + ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) + ] + + # Need to use nullable types, as otherwise type can change when there are missing values. 
+ # See https://arrow.apache.org/docs/python/pandas.html#nullable-types + # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html + dtype_mapping = { + pyarrow.int8(): pandas.Int8Dtype(), + pyarrow.int16(): pandas.Int16Dtype(), + pyarrow.int32(): pandas.Int32Dtype(), + pyarrow.int64(): pandas.Int64Dtype(), + pyarrow.uint8(): pandas.UInt8Dtype(), + pyarrow.uint16(): pandas.UInt16Dtype(), + pyarrow.uint32(): pandas.UInt32Dtype(), + pyarrow.uint64(): pandas.UInt64Dtype(), + pyarrow.bool_(): pandas.BooleanDtype(), + pyarrow.float32(): pandas.Float32Dtype(), + pyarrow.float64(): pandas.Float64Dtype(), + pyarrow.string(): pandas.StringDtype(), + } + + # Need to rename columns, as the to_pandas function cannot handle duplicate column names + table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) + df = table_renamed.to_pandas( + types_mapper=dtype_mapping.get, + date_as_object=True, + timestamp_as_object=True, + ) + + res = df.to_numpy(na_value=None, dtype="object") + return [ResultRow(*v) for v in res] + @property def rownumber(self): return self._next_row_index @@ -98,12 +137,6 @@ def is_staging_operation(self) -> bool: """Whether this result set represents a staging operation.""" return self._is_staging_operation - # Define abstract methods that concrete implementations must implement - @abstractmethod - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - pass - @abstractmethod def fetchone(self) -> Optional[Row]: """Fetch the next row of a query result set.""" @@ -137,8 +170,11 @@ def close(self) -> None: been closed on the server for some other reason, issue a request to the server to close it. """ try: - if self.results: + if self.results is not None: self.results.close() + else: + logger.warning("result set close: queue not initialized") + if ( self.status != CommandState.CLOSED and not self.has_been_closed_server_side @@ -167,7 +203,7 @@ def __init__( t_row_set=None, max_download_threads: int = 10, ssl_options=None, - is_direct_results: bool = True, + has_more_rows: bool = True, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
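Why the nullable dtype mapping above matters, as a short standalone sketch (assumes pyarrow and pandas are installed; not part of the patch itself):

    import pandas
    import pyarrow

    table = pyarrow.table({"n": pyarrow.array([1, None, 3], type=pyarrow.int64())})

    # default conversion promotes the int column to float64 to hold the missing value
    table.to_pandas()["n"].dtype                            # dtype('float64')

    # with the nullable mapping, integers stay integers and the gap becomes pandas.NA
    mapping = {pyarrow.int64(): pandas.Int64Dtype()}
    table.to_pandas(types_mapper=mapping.get)["n"].dtype    # Int64Dtype()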
@@ -182,20 +218,21 @@ def __init__( :param t_row_set: The TRowSet containing result data (if available) :param max_download_threads: Maximum number of download threads for cloud fetch :param ssl_options: SSL options for cloud fetch - :param is_direct_results: Whether there are more rows to fetch + :param has_more_rows: Whether there are more rows to fetch """ + self.num_chunks = 0 # Initialize ThriftResultSet-specific attributes self._use_cloud_fetch = use_cloud_fetch - self.is_direct_results = is_direct_results + self.has_more_rows = has_more_rows # Build the results queue if t_row_set is provided results_queue = None if t_row_set and execute_response.result_format is not None: - from databricks.sql.utils import ResultSetQueueFactory + from databricks.sql.utils import ThriftResultSetQueueFactory # Create the results queue using the provided format - results_queue = ResultSetQueueFactory.build_queue( + results_queue = ThriftResultSetQueueFactory.build_queue( row_set_type=execute_response.result_format, t_row_set=t_row_set, arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", @@ -203,7 +240,12 @@ def __init__( lz4_compressed=execute_response.lz4_compressed, description=execute_response.description, ssl_options=ssl_options, + session_id_hex=connection.get_session_id_hex(), + statement_id=execute_response.command_id.to_hex_guid(), + chunk_id=self.num_chunks, ) + if t_row_set.resultLinks: + self.num_chunks += len(t_row_set.resultLinks) # Call parent constructor with common attributes super().__init__( @@ -214,7 +256,7 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - is_direct_results=is_direct_results, + has_more_rows=has_more_rows, results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, @@ -227,7 +269,7 @@ def __init__( self._fill_results_buffer() def _fill_results_buffer(self): - results, is_direct_results = self.backend.fetch_results( + results, has_more_rows, result_links_count = self.backend.fetch_results( command_id=self.command_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -236,9 +278,11 @@ def _fill_results_buffer(self): arrow_schema_bytes=self._arrow_schema_bytes, description=self.description, use_cloud_fetch=self._use_cloud_fetch, + chunk_id=self.num_chunks, ) self.results = results - self.is_direct_results = is_direct_results + self.has_more_rows = has_more_rows + self.num_chunks += result_links_count def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -252,44 +296,6 @@ def _convert_columnar_table(self, table): return result - def _convert_arrow_table(self, table): - column_names = [c[0] for c in self.description] - ResultRow = Row(*column_names) - - if self.connection.disable_pandas is True: - return [ - ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) - ] - - # Need to use nullable types, as otherwise type can change when there are missing values. 
- # See https://arrow.apache.org/docs/python/pandas.html#nullable-types - # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html - dtype_mapping = { - pyarrow.int8(): pandas.Int8Dtype(), - pyarrow.int16(): pandas.Int16Dtype(), - pyarrow.int32(): pandas.Int32Dtype(), - pyarrow.int64(): pandas.Int64Dtype(), - pyarrow.uint8(): pandas.UInt8Dtype(), - pyarrow.uint16(): pandas.UInt16Dtype(), - pyarrow.uint32(): pandas.UInt32Dtype(), - pyarrow.uint64(): pandas.UInt64Dtype(), - pyarrow.bool_(): pandas.BooleanDtype(), - pyarrow.float32(): pandas.Float32Dtype(), - pyarrow.float64(): pandas.Float64Dtype(), - pyarrow.string(): pandas.StringDtype(), - } - - # Need to rename columns, as the to_pandas function cannot handle duplicate column names - table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) - df = table_renamed.to_pandas( - types_mapper=dtype_mapping.get, - date_as_object=True, - timestamp_as_object=True, - ) - - res = df.to_numpy(na_value=None, dtype="object") - return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": """ Function to merge / combining the columnar results into a single result @@ -323,7 +329,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.is_direct_results + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -348,7 +354,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.is_direct_results + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -363,7 +369,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows partial_result_chunks = [results] - while not self.has_been_closed_server_side and self.is_direct_results: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -389,7 +395,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.is_direct_results: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -448,82 +454,3 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] - - -class SeaResultSet(ResultSet): - """ResultSet implementation for SEA backend.""" - - def __init__( - self, - connection: Connection, - execute_response: ExecuteResponse, - sea_client: SeaDatabricksClient, - buffer_size_bytes: int = 104857600, - arraysize: int = 10000, - result_data=None, - manifest=None, - ): - """ - Initialize a SeaResultSet with the response from a SEA query execution. 
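# Toy illustration of the renamed control flow above: fetchall-style methods drain the
# current queue, then keep calling the backend while it reports has_more_rows.
# FakeBackend stands in for ThriftDatabricksClient.fetch_results (which now also
# returns a result-link count, omitted here for brevity); it is not the real client.
from typing import List, Tuple

class FakeBackend:
    def __init__(self, batches: List[List[int]]):
        self._batches = batches
        self._index = 0

    def fetch_results(self) -> Tuple[List[int], bool]:
        batch = self._batches[self._index]
        self._index += 1
        return batch, self._index < len(self._batches)  # (rows, has_more_rows)

def fetchall(backend: FakeBackend) -> List[int]:
    rows, has_more_rows = backend.fetch_results()
    results = list(rows)
    while has_more_rows:
        rows, has_more_rows = backend.fetch_results()
        results.extend(rows)
    return results

print(fetchall(FakeBackend([[1, 2], [3], [4, 5]])))  # [1, 2, 3, 4, 5]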
- - Args: - connection: The parent connection - execute_response: Response from the execute command - sea_client: The SeaDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - result_data: Result data from SEA response (optional) - manifest: Manifest from SEA response (optional) - """ - - super().__init__( - connection=connection, - backend=sea_client, - arraysize=arraysize, - buffer_size_bytes=buffer_size_bytes, - command_id=execute_response.command_id, - status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, - description=execute_response.description, - is_staging_operation=execute_response.is_staging_operation, - lz4_compressed=execute_response.lz4_compressed, - arrow_schema_bytes=execute_response.arrow_schema_bytes, - ) - - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - raise NotImplementedError( - "_fill_results_buffer is not implemented for SEA backend" - ) - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchmany(self, size: Optional[int] = None) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - - raise NotImplementedError("fetchmany is not implemented for SEA backend") - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. - """ - - raise NotImplementedError("fetchall is not implemented for SEA backend") - - def fetchmany_arrow(self, size: int) -> Any: - """Fetch the next set of rows as an Arrow table.""" - raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") - - def fetchall_arrow(self) -> Any: - """Fetch all remaining rows as an Arrow table.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index cc60a61b5..f1bc35bee 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -65,7 +65,7 @@ def __init__( base_headers = [("User-Agent", self.useragent_header)] all_headers = (http_headers or []) + base_headers - self._ssl_options = SSLOptions( + self.ssl_options = SSLOptions( # Double negation is generally a bad thing, but we have to keep backward compatibility tls_verify=not kwargs.get( "_tls_no_verify", False @@ -98,10 +98,10 @@ def _create_backend( kwargs: dict, ) -> DatabricksClient: """Create and return the appropriate backend client.""" - use_sea = kwargs.get("use_sea", False) + self.use_sea = kwargs.get("use_sea", False) databricks_client_class: Type[DatabricksClient] - if use_sea: + if self.use_sea: logger.debug("Creating SEA backend client") databricks_client_class = SeaDatabricksClient else: @@ -114,7 +114,7 @@ def _create_backend( "http_path": http_path, "http_headers": all_headers, "auth_provider": auth_provider, - "ssl_options": self._ssl_options, + "ssl_options": self.ssl_options, "_use_arrow_native_complex_types": _use_arrow_native_complex_types, **kwargs, } diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 0b0c564da..12cacd851 100644 --- a/src/databricks/sql/telemetry/latency_logger.py 
+++ b/src/databricks/sql/telemetry/latency_logger.py @@ -7,8 +7,6 @@ SqlExecutionEvent, ) from databricks.sql.telemetry.models.enums import ExecutionResultFormat, StatementType -from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue -from uuid import UUID logger = logging.getLogger(__name__) @@ -36,12 +34,15 @@ def get_statement_id(self): def get_is_compressed(self): pass - def get_execution_result(self): + def get_execution_result_format(self): pass def get_retry_count(self): pass + def get_chunk_id(self): + pass + class CursorExtractor(TelemetryExtractor): """ @@ -60,10 +61,12 @@ def get_session_id_hex(self) -> Optional[str]: def get_is_compressed(self) -> bool: return self.connection.lz4_compression - def get_execution_result(self) -> ExecutionResultFormat: + def get_execution_result_format(self) -> ExecutionResultFormat: if self.active_result_set is None: return ExecutionResultFormat.FORMAT_UNSPECIFIED + from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue + if isinstance(self.active_result_set.results, ColumnQueue): return ExecutionResultFormat.COLUMNAR_INLINE elif isinstance(self.active_result_set.results, CloudFetchQueue): @@ -73,49 +76,37 @@ def get_execution_result(self) -> ExecutionResultFormat: return ExecutionResultFormat.FORMAT_UNSPECIFIED def get_retry_count(self) -> int: - if ( - hasattr(self.thrift_backend, "retry_policy") - and self.thrift_backend.retry_policy - ): - return len(self.thrift_backend.retry_policy.history) + if hasattr(self.backend, "retry_policy") and self.backend.retry_policy: + return len(self.backend.retry_policy.history) return 0 + def get_chunk_id(self): + return None -class ResultSetExtractor(TelemetryExtractor): - """ - Telemetry extractor specialized for ResultSet objects. - Extracts telemetry information from database result set objects, including - operation IDs, session information, compression settings, and result formats. +class ResultSetDownloadHandlerExtractor(TelemetryExtractor): + """ + Telemetry extractor specialized for ResultSetDownloadHandler objects. 
""" - - def get_statement_id(self) -> Optional[str]: - if self.command_id: - return str(UUID(bytes=self.command_id.operationId.guid)) - return None def get_session_id_hex(self) -> Optional[str]: - return self.connection.get_session_id_hex() + return self._obj.session_id_hex + + def get_statement_id(self) -> Optional[str]: + return self._obj.statement_id def get_is_compressed(self) -> bool: - return self.lz4_compressed + return self._obj.settings.is_lz4_compressed - def get_execution_result(self) -> ExecutionResultFormat: - if isinstance(self.results, ColumnQueue): - return ExecutionResultFormat.COLUMNAR_INLINE - elif isinstance(self.results, CloudFetchQueue): - return ExecutionResultFormat.EXTERNAL_LINKS - elif isinstance(self.results, ArrowQueue): - return ExecutionResultFormat.INLINE_ARROW - return ExecutionResultFormat.FORMAT_UNSPECIFIED + def get_execution_result_format(self) -> ExecutionResultFormat: + return ExecutionResultFormat.EXTERNAL_LINKS - def get_retry_count(self) -> int: - if ( - hasattr(self.thrift_backend, "retry_policy") - and self.thrift_backend.retry_policy - ): - return len(self.thrift_backend.retry_policy.history) - return 0 + def get_retry_count(self) -> Optional[int]: + # standard requests and urllib3 libraries don't expose retry count + return None + + def get_chunk_id(self) -> Optional[int]: + return self._obj.chunk_id def get_extractor(obj): @@ -126,19 +117,19 @@ def get_extractor(obj): that can extract telemetry information from that object type. Args: - obj: The object to create an extractor for. Can be a Cursor, ResultSet, - or any other object. + obj: The object to create an extractor for. Can be a Cursor, + ResultSetDownloadHandler, or any other object. Returns: TelemetryExtractor: A specialized extractor instance: - CursorExtractor for Cursor objects - - ResultSetExtractor for ResultSet objects + - ResultSetDownloadHandlerExtractor for ResultSetDownloadHandler objects - None for all other objects """ if obj.__class__.__name__ == "Cursor": return CursorExtractor(obj) - elif obj.__class__.__name__ == "ResultSet": - return ResultSetExtractor(obj) + elif obj.__class__.__name__ == "ResultSetDownloadHandler": + return ResultSetDownloadHandlerExtractor(obj) else: logger.debug("No extractor found for %s", obj.__class__.__name__) return None @@ -162,7 +153,7 @@ def log_latency(statement_type: StatementType = StatementType.NONE): statement_type (StatementType): The type of SQL statement being executed. 
Usage: - @log_latency(StatementType.SQL) + @log_latency(StatementType.QUERY) def execute(self, query): # Method implementation pass @@ -204,8 +195,11 @@ def _safe_call(func_to_call): sql_exec_event = SqlExecutionEvent( statement_type=statement_type, is_compressed=_safe_call(extractor.get_is_compressed), - execution_result=_safe_call(extractor.get_execution_result), + execution_result=_safe_call( + extractor.get_execution_result_format + ), retry_count=_safe_call(extractor.get_retry_count), + chunk_id=_safe_call(extractor.get_chunk_id), ) telemetry_client = TelemetryClientFactory.get_telemetry_client( diff --git a/src/databricks/sql/telemetry/models/event.py b/src/databricks/sql/telemetry/models/event.py index a155c7597..c7f9d9d17 100644 --- a/src/databricks/sql/telemetry/models/event.py +++ b/src/databricks/sql/telemetry/models/event.py @@ -122,12 +122,14 @@ class SqlExecutionEvent(JsonSerializableMixin): is_compressed (bool): Whether the result is compressed execution_result (ExecutionResultFormat): Format of the execution result retry_count (int): Number of retry attempts made + chunk_id (int): ID of the chunk if applicable """ statement_type: StatementType is_compressed: bool execution_result: ExecutionResultFormat - retry_count: int + retry_count: Optional[int] + chunk_id: Optional[int] @dataclass diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 7e8a4fa0c..4617f7de6 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -1,4 +1,5 @@ from __future__ import annotations +from typing import Dict, List, Optional, Union from dateutil import parser import datetime @@ -18,7 +19,7 @@ except ImportError: pyarrow = None -from databricks.sql import OperationalError, exc +from databricks.sql import OperationalError from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager from databricks.sql.thrift_api.TCLIService.ttypes import ( TRowSet, @@ -27,7 +28,7 @@ ) from databricks.sql.types import SSLOptions from databricks.sql.backend.types import CommandId - +from databricks.sql.telemetry.models.event import StatementType from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter import logging @@ -52,7 +53,7 @@ def close(self): pass -class ResultSetQueueFactory(ABC): +class ThriftResultSetQueueFactory(ABC): @staticmethod def build_queue( row_set_type: TSparkRowSetType, @@ -60,11 +61,14 @@ def build_queue( arrow_schema_bytes: bytes, max_download_threads: int, ssl_options: SSLOptions, + session_id_hex: Optional[str], + statement_id: str, + chunk_id: int, lz4_compressed: bool = True, - description: Optional[List[Tuple]] = None, + description: List[Tuple] = [], ) -> ResultSetQueue: """ - Factory method to build a result set queue. + Factory method to build a result set queue for Thrift backend. Args: row_set_type (enum): Row set type (Arrow, Column, or URL). 
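# Hedged sketch of the widened telemetry event shape: retry_count and the new chunk_id
# are Optional, so producers that cannot observe them (e.g. the download handler, which
# has no retry history) can pass None. The enums here are simplified stand-ins; only
# the field names mirror SqlExecutionEvent above.
from dataclasses import dataclass, asdict
from enum import Enum
from typing import Optional

class StatementType(Enum):
    QUERY = "query"

class ExecutionResultFormat(Enum):
    EXTERNAL_LINKS = "external_links"

@dataclass
class SqlExecutionEvent:
    statement_type: StatementType
    is_compressed: bool
    execution_result: ExecutionResultFormat
    retry_count: Optional[int]
    chunk_id: Optional[int]

event = SqlExecutionEvent(
    statement_type=StatementType.QUERY,
    is_compressed=True,
    execution_result=ExecutionResultFormat.EXTERNAL_LINKS,
    retry_count=None,
    chunk_id=3,
)
print(asdict(event))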
@@ -98,7 +102,7 @@ def build_queue( return ColumnQueue(ColumnTable(converted_column_table, column_names)) elif row_set_type == TSparkRowSetType.URL_BASED_SET: - return CloudFetchQueue( + return ThriftCloudFetchQueue( schema_bytes=arrow_schema_bytes, start_row_offset=t_row_set.startRowOffset, result_links=t_row_set.resultLinks, @@ -106,6 +110,9 @@ def build_queue( description=description, max_download_threads=max_download_threads, ssl_options=ssl_options, + session_id_hex=session_id_hex, + statement_id=statement_id, + chunk_id=chunk_id, ) else: raise AssertionError("Row set type is not valid") @@ -207,66 +214,61 @@ def close(self): return -class CloudFetchQueue(ResultSetQueue): +class CloudFetchQueue(ResultSetQueue, ABC): + """Base class for cloud fetch queues that handle EXTERNAL_LINKS disposition with ARROW format.""" + def __init__( self, - schema_bytes, max_download_threads: int, ssl_options: SSLOptions, - start_row_offset: int = 0, - result_links: Optional[List[TSparkArrowResultLink]] = None, + session_id_hex: Optional[str], + statement_id: str, + chunk_id: int, + schema_bytes: Optional[bytes] = None, lz4_compressed: bool = True, - description: Optional[List[Tuple]] = None, + description: List[Tuple] = [], ): """ - A queue-like wrapper over CloudFetch arrow batches. + Initialize the base CloudFetchQueue. - Attributes: - schema_bytes (bytes): Table schema in bytes. - max_download_threads (int): Maximum number of downloader thread pool threads. - start_row_offset (int): The offset of the first row of the cloud fetch links. - result_links (List[TSparkArrowResultLink]): Links containing the downloadable URL and metadata. - lz4_compressed (bool): Whether the files are lz4 compressed. - description (List[List[Any]]): Hive table schema description. + Args: + max_download_threads: Maximum number of download threads + ssl_options: SSL options for downloads + schema_bytes: Arrow schema bytes + lz4_compressed: Whether the data is LZ4 compressed + description: Column descriptions """ self.schema_bytes = schema_bytes self.max_download_threads = max_download_threads - self.start_row_index = start_row_offset - self.result_links = result_links self.lz4_compressed = lz4_compressed self.description = description self._ssl_options = ssl_options + self.session_id_hex = session_id_hex + self.statement_id = statement_id + self.chunk_id = chunk_id - logger.debug( - "Initialize CloudFetch loader, row set start offset: {}, file list:".format( - start_row_offset - ) - ) - if result_links is not None: - for result_link in result_links: - logger.debug( - "- start row offset: {}, row count: {}".format( - result_link.startRowOffset, result_link.rowCount - ) - ) + # Table state + self.table = None + self.table_row_index = 0 + + # Initialize download manager self.download_manager = ResultFileDownloadManager( - links=result_links or [], - max_download_threads=self.max_download_threads, - lz4_compressed=self.lz4_compressed, - ssl_options=self._ssl_options, + links=[], + max_download_threads=max_download_threads, + lz4_compressed=lz4_compressed, + ssl_options=ssl_options, + session_id_hex=session_id_hex, + statement_id=statement_id, + chunk_id=chunk_id, ) - self.table = self._create_next_table() - self.table_row_index = 0 - def next_n_rows(self, num_rows: int) -> "pyarrow.Table": """ Get up to the next n rows of the cloud fetch Arrow dataframes. Args: num_rows (int): Number of rows to retrieve. 
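# Toy version of the base/subclass split introduced above: the abstract base owns the
# "build a table at a given offset" primitive, while the Thrift subclass decides which
# offset comes next and advances it. Downloaded Arrow files are faked with a dict of
# offset -> rows; none of these classes are the driver's real queues.
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

class BaseCloudFetchQueue(ABC):
    def __init__(self, files: Dict[int, List[int]]):
        self._files = files

    def _create_table_at_offset(self, offset: int) -> Optional[List[int]]:
        return self._files.get(offset)  # None signals "no more downloaded files"

    @abstractmethod
    def _create_next_table(self) -> Optional[List[int]]:
        ...

class ThriftQueue(BaseCloudFetchQueue):
    def __init__(self, files, start_row_offset: int = 0):
        super().__init__(files)
        self.start_row_index = start_row_offset

    def _create_next_table(self):
        table = self._create_table_at_offset(self.start_row_index)
        if table is not None:
            self.start_row_index += len(table)
        return table

queue = ThriftQueue({0: [0, 1, 2], 3: [3, 4]})
print(queue._create_next_table())  # [0, 1, 2]
print(queue._create_next_table())  # [3, 4]
print(queue._create_next_table())  # None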
- Returns: pyarrow.Table """ @@ -317,21 +319,14 @@ def remaining_rows(self) -> "pyarrow.Table": self.table_row_index = 0 return pyarrow.concat_tables(partial_result_chunks, use_threads=True) - def _create_next_table(self) -> Union["pyarrow.Table", None]: - logger.debug( - "CloudFetchQueue: Trying to get downloaded file for row {}".format( - self.start_row_index - ) - ) + def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]: + """Create next table at the given row offset""" + # Create next table by retrieving the logical next downloaded file, or return None to signal end of queue - downloaded_file = self.download_manager.get_next_downloaded_file( - self.start_row_index - ) + downloaded_file = self.download_manager.get_next_downloaded_file(offset) if not downloaded_file: logger.debug( - "CloudFetchQueue: Cannot find downloaded file for row {}".format( - self.start_row_index - ) + "CloudFetchQueue: Cannot find downloaded file for row {}".format(offset) ) # None signals no more Arrow tables can be built from the remaining handlers if any remain return None @@ -346,24 +341,103 @@ def _create_next_table(self) -> Union["pyarrow.Table", None]: # At this point, whether the file has extraneous rows or not, the arrow table should have the correct num rows assert downloaded_file.row_count == arrow_table.num_rows - self.start_row_index += arrow_table.num_rows - - logger.debug( - "CloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format( - arrow_table.num_rows, self.start_row_index - ) - ) return arrow_table + @abstractmethod + def _create_next_table(self) -> Union["pyarrow.Table", None]: + """Create next table by retrieving the logical next downloaded file.""" + pass + def _create_empty_table(self) -> "pyarrow.Table": - # Create a 0-row table with just the schema bytes + """Create a 0-row table with just the schema bytes.""" + if not self.schema_bytes: + return pyarrow.Table.from_pydict({}) return create_arrow_table_from_arrow_file(self.schema_bytes, self.description) def close(self): self.download_manager._shutdown_manager() +class ThriftCloudFetchQueue(CloudFetchQueue): + """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for Thrift backend.""" + + def __init__( + self, + schema_bytes, + max_download_threads: int, + ssl_options: SSLOptions, + session_id_hex: Optional[str], + statement_id: str, + chunk_id: int, + start_row_offset: int = 0, + result_links: Optional[List[TSparkArrowResultLink]] = None, + lz4_compressed: bool = True, + description: List[Tuple] = [], + ): + """ + Initialize the Thrift CloudFetchQueue. 
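# Hedged sketch of the "0-row table from schema bytes" fallback above, using only
# public pyarrow IPC APIs: write a schema-only stream, then read it back as an empty
# table. The driver does the equivalent via create_arrow_table_from_arrow_file.
import pyarrow
import pyarrow.ipc

schema = pyarrow.schema([("id", pyarrow.int64()), ("value", pyarrow.string())])
sink = pyarrow.BufferOutputStream()
writer = pyarrow.ipc.new_stream(sink, schema)
writer.close()
schema_bytes = sink.getvalue().to_pybytes()

empty_table = pyarrow.ipc.open_stream(schema_bytes).read_all()
print(empty_table.num_rows, empty_table.schema.names)  # 0 ['id', 'value']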
+ + Args: + schema_bytes: Table schema in bytes + max_download_threads: Maximum number of downloader thread pool threads + ssl_options: SSL options for downloads + start_row_offset: The offset of the first row of the cloud fetch links + result_links: Links containing the downloadable URL and metadata + lz4_compressed: Whether the files are lz4 compressed + description: Hive table schema description + """ + super().__init__( + max_download_threads=max_download_threads, + ssl_options=ssl_options, + schema_bytes=schema_bytes, + lz4_compressed=lz4_compressed, + description=description, + session_id_hex=session_id_hex, + statement_id=statement_id, + chunk_id=chunk_id, + ) + + self.start_row_index = start_row_offset + self.result_links = result_links or [] + self.session_id_hex = session_id_hex + self.statement_id = statement_id + self.chunk_id = chunk_id + + logger.debug( + "Initialize CloudFetch loader, row set start offset: {}, file list:".format( + start_row_offset + ) + ) + if self.result_links: + for result_link in self.result_links: + logger.debug( + "- start row offset: {}, row count: {}".format( + result_link.startRowOffset, result_link.rowCount + ) + ) + self.download_manager.add_link(result_link) + + # Initialize table and position + self.table = self._create_next_table() + + def _create_next_table(self) -> Union["pyarrow.Table", None]: + logger.debug( + "ThriftCloudFetchQueue: Trying to get downloaded file for row {}".format( + self.start_row_index + ) + ) + arrow_table = self._create_table_at_offset(self.start_row_index) + if arrow_table: + self.start_row_index += arrow_table.num_rows + logger.debug( + "ThriftCloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format( + arrow_table.num_rows, self.start_row_index + ) + ) + return arrow_table + + def _bound(min_x, max_x, x): """Bound x by [min_x, max_x] diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py index 1181ef154..aeeb67974 100644 --- a/tests/e2e/common/large_queries_mixin.py +++ b/tests/e2e/common/large_queries_mixin.py @@ -2,6 +2,8 @@ import math import time +import pytest + log = logging.getLogger(__name__) @@ -42,7 +44,14 @@ def fetch_rows(self, cursor, row_count, fetchmany_size): + "assuming 10K fetch size." ) - def test_query_with_large_wide_result_set(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_query_with_large_wide_result_set(self, extra_params): resultSize = 300 * 1000 * 1000 # 300 MB width = 8192 # B rows = resultSize // width @@ -52,7 +61,7 @@ def test_query_with_large_wide_result_set(self): fetchmany_size = 10 * 1024 * 1024 // width # This is used by PyHive tests to determine the buffer size self.arraysize = 1000 - with self.cursor() as cursor: + with self.cursor(extra_params) as cursor: for lz4_compression in [False, True]: cursor.connection.lz4_compression = lz4_compression uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)]) @@ -68,7 +77,14 @@ def test_query_with_large_wide_result_set(self): assert row[0] == row_id # Verify no rows are dropped in the middle. 
assert len(row[1]) == 36 - def test_query_with_large_narrow_result_set(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_query_with_large_narrow_result_set(self, extra_params): resultSize = 300 * 1000 * 1000 # 300 MB width = 8 # sizeof(long) rows = resultSize / width @@ -77,12 +93,19 @@ def test_query_with_large_narrow_result_set(self): fetchmany_size = 10 * 1024 * 1024 // width # This is used by PyHive tests to determine the buffer size self.arraysize = 10000000 - with self.cursor() as cursor: + with self.cursor(extra_params) as cursor: cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows)) for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)): assert row[0] == row_id - def test_long_running_query(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_long_running_query(self, extra_params): """Incrementally increase query size until it takes at least 3 minutes, and asserts that the query completes successfully. """ @@ -92,7 +115,7 @@ def test_long_running_query(self): duration = -1 scale0 = 10000 scale_factor = 1 - with self.cursor() as cursor: + with self.cursor(extra_params) as cursor: while duration < min_duration: assert scale_factor < 1024, "Detected infinite loop" start = time.time() diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index 096247a42..3eb1745ab 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -343,7 +343,9 @@ def test_retry_abort_close_operation_on_404(self, caplog): ) @patch("databricks.sql.telemetry.telemetry_client.TelemetryClient._send_telemetry") - def test_retry_max_redirects_raises_too_many_redirects_exception(self, mock_send_telemetry): + def test_retry_max_redirects_raises_too_many_redirects_exception( + self, mock_send_telemetry + ): """GIVEN the connector is configured with a custom max_redirects WHEN the DatabricksRetryPolicy is created THEN the connector raises a MaxRedirectsError if that number is exceeded @@ -368,7 +370,9 @@ def test_retry_max_redirects_raises_too_many_redirects_exception(self, mock_send assert mock_obj.return_value.getresponse.call_count == expected_call_count @patch("databricks.sql.telemetry.telemetry_client.TelemetryClient._send_telemetry") - def test_retry_max_redirects_unset_doesnt_redirect_forever(self, mock_send_telemetry): + def test_retry_max_redirects_unset_doesnt_redirect_forever( + self, mock_send_telemetry + ): """GIVEN the connector is configured without a custom max_redirects WHEN the DatabricksRetryPolicy is used THEN the connector raises a MaxRedirectsError if that number is exceeded diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 8f15bccc6..3fa87b1af 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -182,10 +182,19 @@ def test_cloud_fetch(self): class TestPySQLAsyncQueriesSuite(PySQLPytestTestCase): - def test_execute_async__long_running(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) + def test_execute_async__long_running(self, extra_params): long_running_query = "SELECT COUNT(*) FROM RANGE(10000 * 16) x JOIN RANGE(10000) y ON FROM_UNIXTIME(x.id * y.id, 'yyyy-MM-dd') LIKE '%not%a%date%'" - with self.cursor() as cursor: + with self.cursor(extra_params) as cursor: cursor.execute_async(long_running_query) ## Polling after every POLLING_INTERVAL seconds @@ -198,10 +207,21 @@ def 
test_execute_async__long_running(self): assert result[0].asDict() == {"count(1)": 0} - def test_execute_async__small_result(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_execute_async__small_result(self, extra_params): small_result_query = "SELECT 1" - with self.cursor() as cursor: + with self.cursor(extra_params) as cursor: cursor.execute_async(small_result_query) ## Fake sleep for 5 secs @@ -217,7 +237,16 @@ def test_execute_async__small_result(self): assert result[0].asDict() == {"1": 1} - def test_execute_async__large_result(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) + def test_execute_async__large_result(self, extra_params): x_dimension = 1000 y_dimension = 1000 large_result_query = f""" @@ -231,7 +260,7 @@ def test_execute_async__large_result(self): RANGE({y_dimension}) y """ - with self.cursor() as cursor: + with self.cursor(extra_params) as cursor: cursor.execute_async(large_result_query) ## Fake sleep for 5 secs @@ -330,8 +359,22 @@ def test_incorrect_query_throws_exception(self): cursor.execute("CREATE TABLE IF NOT EXISTS TABLE table_234234234") assert "table_234234234" in str(cm.value) - def test_create_table_will_return_empty_result_set(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + { + "use_sea": True, + }, + ], + ) + def test_create_table_will_return_empty_result_set(self, extra_params): + with self.cursor(extra_params) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) try: cursor.execute( @@ -529,10 +572,24 @@ def test_get_catalogs(self): ] @skipUnless(pysql_supports_arrow(), "arrow test need arrow support") - def test_get_arrow(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + { + "use_sea": True, + }, + ], + ) + def test_get_arrow(self, extra_params): # These tests are quite light weight as the arrow fetch methods are used internally # by everything else - with self.cursor({}) as cursor: + with self.cursor(extra_params) as cursor: cursor.execute("SELECT * FROM range(10)") table_1 = cursor.fetchmany_arrow(1).to_pydict() assert table_1 == OrderedDict([("id", [0])]) @@ -540,9 +597,20 @@ def test_get_arrow(self): table_2 = cursor.fetchall_arrow().to_pydict() assert table_2 == OrderedDict([("id", [1, 2, 3, 4, 5, 6, 7, 8, 9])]) - def test_unicode(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_unicode(self, extra_params): unicode_str = "数据砖" - with self.cursor({}) as cursor: + with self.cursor(extra_params) as cursor: cursor.execute("SELECT '{}'".format(unicode_str)) results = cursor.fetchall() assert len(results) == 1 and len(results[0]) == 1 @@ -580,8 +648,22 @@ def execute_really_long_query(): assert len(cursor.fetchall()) == 3 @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_can_execute_command_after_failure(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + { + "use_sea": 
True, + }, + ], + ) + def test_can_execute_command_after_failure(self, extra_params): + with self.cursor(extra_params) as cursor: with pytest.raises(DatabaseError): cursor.execute("this is a sytnax error") @@ -591,8 +673,22 @@ def test_can_execute_command_after_failure(self): self.assertEqualRowValues(res, [[1]]) @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_can_execute_command_after_success(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + { + "use_sea": True, + }, + ], + ) + def test_can_execute_command_after_success(self, extra_params): + with self.cursor(extra_params) as cursor: cursor.execute("SELECT 1;") cursor.execute("SELECT 2;") @@ -604,8 +700,22 @@ def generate_multi_row_query(self): return query @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_fetchone(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + { + "use_sea": True, + }, + ], + ) + def test_fetchone(self, extra_params): + with self.cursor(extra_params) as cursor: query = self.generate_multi_row_query() cursor.execute(query) @@ -616,8 +726,19 @@ def test_fetchone(self): assert cursor.fetchone() == None @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_fetchall(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_fetchall(self, extra_params): + with self.cursor(extra_params) as cursor: query = self.generate_multi_row_query() cursor.execute(query) @@ -626,8 +747,22 @@ def test_fetchall(self): assert cursor.fetchone() == None @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_fetchmany_when_stride_fits(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + { + "use_sea": True, + }, + ], + ) + def test_fetchmany_when_stride_fits(self, extra_params): + with self.cursor(extra_params) as cursor: query = "SELECT * FROM range(4)" cursor.execute(query) @@ -635,8 +770,22 @@ def test_fetchmany_when_stride_fits(self): self.assertEqualRowValues(cursor.fetchmany(2), [[2], [3]]) @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_fetchmany_in_excess(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + { + "use_sea": True, + }, + ], + ) + def test_fetchmany_in_excess(self, extra_params): + with self.cursor(extra_params) as cursor: query = "SELECT * FROM range(4)" cursor.execute(query) @@ -644,8 +793,22 @@ def test_fetchmany_in_excess(self): self.assertEqualRowValues(cursor.fetchmany(3), [[3]]) @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_iterator_api(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + { + "use_sea": True, + }, + ], + ) + def test_iterator_api(self, extra_params): + with self.cursor(extra_params) as 
cursor: query = "SELECT * FROM range(4)" cursor.execute(query) @@ -718,8 +881,24 @@ def test_timestamps_arrow(self): ), "timestamp {} did not match {}".format(timestamp, expected) @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") - def test_multi_timestamps_arrow(self): - with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + { + "use_sea": True, + }, + ], + ) + def test_multi_timestamps_arrow(self, extra_params): + with self.cursor( + {"session_configuration": {"ansi_mode": False}, **extra_params} + ) as cursor: query, expected = self.multi_query() expected = [ [self.maybe_add_timezone_to_timestamp(ts) for ts in row] diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 520a0f377..4271f0d7d 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -46,7 +46,7 @@ def new(cls): is_staging_operation=False, command_id=None, has_been_closed_server_side=True, - is_direct_results=True, + has_more_rows=True, lz4_compressed=True, arrow_schema_bytes=b"schema", ) @@ -109,11 +109,12 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): mock_execute_response.has_been_closed_server_side = closed mock_execute_response.is_staging_operation = False mock_execute_response.command_id = Mock(spec=CommandId) + mock_execute_response.description = [] # Mock the backend that will be used mock_backend = Mock(spec=ThriftDatabricksClient) mock_backend.staging_allowed_local_path = None - mock_backend.fetch_results.return_value = (Mock(), False) + mock_backend.fetch_results.return_value = (Mock(), False, 0) # Configure the decorator's mock to return our specific mock_backend mock_thrift_client_class.return_value = mock_backend @@ -138,7 +139,6 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): connection.close() # Verify the close logic worked: - # 1. has_been_closed_server_side should always be True after close() assert real_result_set.has_been_closed_server_side is True # 2. 
op_state should always be CLOSED after close() @@ -180,7 +180,7 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() mock_results = Mock() - mock_backend.fetch_results.return_value = (Mock(), False) + mock_backend.fetch_results.return_value = (Mock(), False, 0) result_set = ThriftResultSet( connection=mock_connection, @@ -210,7 +210,7 @@ def test_closing_result_set_hard_closes_commands(self): mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) - mock_thrift_backend.fetch_results.return_value = (Mock(), False) + mock_thrift_backend.fetch_results.return_value = (Mock(), False, 0) result_set = ThriftResultSet( mock_connection, mock_results_response, @@ -231,12 +231,6 @@ def test_executing_multiple_commands_uses_the_most_recent_command(self): for mock_rs in mock_result_sets: mock_rs.is_staging_operation = False - mock_backend = ThriftDatabricksClientMockFactory.new() - mock_backend.execute_command.side_effect = mock_result_sets - # Set is_staging_operation to False to avoid _handle_staging_operation being called - for mock_rs in mock_result_sets: - mock_rs.is_staging_operation = False - mock_backend = ThriftDatabricksClientMockFactory.new() mock_backend.execute_command.side_effect = mock_result_sets @@ -266,9 +260,11 @@ def test_closed_cursor_doesnt_allow_operations(self): def test_negative_fetch_throws_exception(self): mock_backend = Mock() - mock_backend.fetch_results.return_value = (Mock(), False) + mock_backend.fetch_results.return_value = (Mock(), False, 0) - result_set = ThriftResultSet(Mock(), Mock(), mock_backend) + result_set = ThriftResultSet( + Mock(), Mock(), mock_backend + ) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) @@ -563,7 +559,10 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_staging_operation_response_is_handled( - self, mock_client_class, mock_handle_staging_operation, mock_execute_response + self, + mock_client_class, + mock_handle_staging_operation, + mock_execute_response, ): # If server sets ExecuteResponse.is_staging_operation True then _handle_staging_operation should be called diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index 7dec4e680..faa8e2f99 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -4,7 +4,7 @@ pyarrow = None import unittest import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, Mock from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink import databricks.sql.utils as utils @@ -52,17 +52,20 @@ def get_schema_bytes(): return sink.getvalue().to_pybytes() @patch( - "databricks.sql.utils.CloudFetchQueue._create_next_table", + "databricks.sql.utils.ThriftCloudFetchQueue._create_next_table", return_value=[None, None], ) def test_initializer_adds_links(self, mock_create_next_table): schema_bytes = MagicMock() result_links = self.create_result_links(10) - queue = utils.CloudFetchQueue( + queue = utils.ThriftCloudFetchQueue( schema_bytes, result_links=result_links, max_download_threads=10, ssl_options=SSLOptions(), + session_id_hex=Mock(), + statement_id=Mock(), + chunk_id=0, ) assert len(queue.download_manager._pending_links) == 10 @@ -72,11 +75,14 @@ def test_initializer_adds_links(self, 
mock_create_next_table): def test_initializer_no_links_to_add(self): schema_bytes = MagicMock() result_links = [] - queue = utils.CloudFetchQueue( + queue = utils.ThriftCloudFetchQueue( schema_bytes, result_links=result_links, max_download_threads=10, ssl_options=SSLOptions(), + session_id_hex=Mock(), + statement_id=Mock(), + chunk_id=0, ) assert len(queue.download_manager._pending_links) == 0 @@ -88,11 +94,14 @@ def test_initializer_no_links_to_add(self): return_value=None, ) def test_create_next_table_no_download(self, mock_get_next_downloaded_file): - queue = utils.CloudFetchQueue( + queue = utils.ThriftCloudFetchQueue( MagicMock(), result_links=[], max_download_threads=10, ssl_options=SSLOptions(), + session_id_hex=Mock(), + statement_id=Mock(), + chunk_id=0, ) assert queue._create_next_table() is None @@ -108,12 +117,15 @@ def test_initializer_create_next_table_success( ): mock_create_arrow_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue( + queue = utils.ThriftCloudFetchQueue( schema_bytes, result_links=[], description=description, max_download_threads=10, ssl_options=SSLOptions(), + session_id_hex=Mock(), + statement_id=Mock(), + chunk_id=0, ) expected_result = self.make_arrow_table() @@ -129,16 +141,19 @@ def test_initializer_create_next_table_success( assert table.num_rows == 4 assert queue.start_row_index == 8 - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") + @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") def test_next_n_rows_0_rows(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue( + queue = utils.ThriftCloudFetchQueue( schema_bytes, result_links=[], description=description, max_download_threads=10, ssl_options=SSLOptions(), + session_id_hex=Mock(), + statement_id=Mock(), + chunk_id=0, ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -147,18 +162,20 @@ def test_next_n_rows_0_rows(self, mock_create_next_table): result = queue.next_n_rows(0) assert result.num_rows == 0 assert queue.table_row_index == 0 - assert result == self.make_arrow_table()[0:0] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") + @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") def test_next_n_rows_partial_table(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue( + queue = utils.ThriftCloudFetchQueue( schema_bytes, result_links=[], description=description, max_download_threads=10, ssl_options=SSLOptions(), + session_id_hex=Mock(), + statement_id=Mock(), + chunk_id=0, ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -169,16 +186,19 @@ def test_next_n_rows_partial_table(self, mock_create_next_table): assert queue.table_row_index == 3 assert result == self.make_arrow_table()[:3] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") + @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") def test_next_n_rows_more_than_one_table(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue( + queue = utils.ThriftCloudFetchQueue( schema_bytes, result_links=[], description=description, 
max_download_threads=10, ssl_options=SSLOptions(), + session_id_hex=Mock(), + statement_id=Mock(), + chunk_id=0, ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -194,16 +214,19 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): )[:7] ) - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") + @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue( + queue = utils.ThriftCloudFetchQueue( schema_bytes, result_links=[], description=description, max_download_threads=10, ssl_options=SSLOptions(), + session_id_hex=Mock(), + statement_id=Mock(), + chunk_id=0, ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -213,16 +236,22 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): assert result.num_rows == 4 assert result == self.make_arrow_table() - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=None) + @patch( + "databricks.sql.utils.ThriftCloudFetchQueue._create_next_table", + return_value=None, + ) def test_next_n_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() description = MagicMock() - queue = utils.CloudFetchQueue( + queue = utils.ThriftCloudFetchQueue( schema_bytes, result_links=[], description=description, max_download_threads=10, ssl_options=SSLOptions(), + session_id_hex=Mock(), + statement_id=Mock(), + chunk_id=0, ) assert queue.table is None @@ -230,16 +259,19 @@ def test_next_n_rows_empty_table(self, mock_create_next_table): mock_create_next_table.assert_called() assert result == pyarrow.ipc.open_stream(bytearray(schema_bytes)).read_all() - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") + @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None, 0] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue( + queue = utils.ThriftCloudFetchQueue( schema_bytes, result_links=[], description=description, max_download_threads=10, ssl_options=SSLOptions(), + session_id_hex=Mock(), + statement_id=Mock(), + chunk_id=0, ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -249,16 +281,19 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table) assert result.num_rows == 0 assert result == self.make_arrow_table()[0:0] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") + @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue( + queue = utils.ThriftCloudFetchQueue( schema_bytes, result_links=[], description=description, max_download_threads=10, ssl_options=SSLOptions(), + session_id_hex=Mock(), + statement_id=Mock(), + chunk_id=0, ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -268,16 +303,19 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl assert 
result.num_rows == 2 assert result == self.make_arrow_table()[2:] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") + @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue( + queue = utils.ThriftCloudFetchQueue( schema_bytes, result_links=[], description=description, max_download_threads=10, ssl_options=SSLOptions(), + session_id_hex=Mock(), + statement_id=Mock(), + chunk_id=0, ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -287,7 +325,7 @@ def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): assert result.num_rows == 4 assert result == self.make_arrow_table() - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") + @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") def test_remaining_rows_multiple_tables_fully_returned( self, mock_create_next_table ): @@ -297,12 +335,15 @@ def test_remaining_rows_multiple_tables_fully_returned( None, ] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue( + queue = utils.ThriftCloudFetchQueue( schema_bytes, result_links=[], description=description, max_download_threads=10, ssl_options=SSLOptions(), + session_id_hex=Mock(), + statement_id=Mock(), + chunk_id=0, ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -318,16 +359,22 @@ def test_remaining_rows_multiple_tables_fully_returned( )[3:] ) - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=None) + @patch( + "databricks.sql.utils.ThriftCloudFetchQueue._create_next_table", + return_value=None, + ) def test_remaining_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() description = MagicMock() - queue = utils.CloudFetchQueue( + queue = utils.ThriftCloudFetchQueue( schema_bytes, result_links=[], description=description, max_download_threads=10, ssl_options=SSLOptions(), + session_id_hex=Mock(), + statement_id=Mock(), + chunk_id=0, ) assert queue.table is None diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py index 64edbdebe..6eb17a05a 100644 --- a/tests/unit/test_download_manager.py +++ b/tests/unit/test_download_manager.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock import databricks.sql.cloudfetch.download_manager as download_manager from databricks.sql.types import SSLOptions @@ -19,6 +19,9 @@ def create_download_manager( max_download_threads, lz4_compressed, ssl_options=SSLOptions(), + session_id_hex=Mock(), + statement_id=Mock(), + chunk_id=0, ) def create_result_link( diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index 1013ba999..a7cd92a51 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -30,7 +30,12 @@ def test_run_link_expired(self, mock_time): # Already expired result_link.expiryTime = 999 d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() + settings, + result_link, + ssl_options=SSLOptions(), + chunk_id=0, + session_id_hex=Mock(), + statement_id=Mock(), ) with self.assertRaises(Error) as context: @@ -46,7 +51,12 @@ def test_run_link_past_expiry_buffer(self, mock_time): # Within the expiry 
buffer time result_link.expiryTime = 1004 d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() + settings, + result_link, + ssl_options=SSLOptions(), + chunk_id=0, + session_id_hex=Mock(), + statement_id=Mock(), ) with self.assertRaises(Error) as context: @@ -69,7 +79,12 @@ def test_run_get_response_not_ok(self, mock_time): return_value=create_response(status_code=404, _content=b"1234"), ): d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() + settings, + result_link, + ssl_options=SSLOptions(), + chunk_id=0, + session_id_hex=Mock(), + statement_id=Mock(), ) with self.assertRaises(requests.exceptions.HTTPError) as context: d.run() @@ -89,7 +104,12 @@ def test_run_uncompressed_successful(self, mock_time): return_value=create_response(status_code=200, _content=file_bytes), ): d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() + settings, + result_link, + ssl_options=SSLOptions(), + chunk_id=0, + session_id_hex=Mock(), + statement_id=Mock(), ) file = d.run() @@ -110,7 +130,12 @@ def test_run_compressed_successful(self, mock_time): return_value=create_response(status_code=200, _content=compressed_bytes), ): d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() + settings, + result_link, + ssl_options=SSLOptions(), + chunk_id=0, + session_id_hex=Mock(), + statement_id=Mock(), ) file = d.run() @@ -127,7 +152,12 @@ def test_download_connection_error(self, mock_time): with patch.object(http_client, "execute", side_effect=ConnectionError("foo")): d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() + settings, + result_link, + ssl_options=SSLOptions(), + chunk_id=0, + session_id_hex=Mock(), + statement_id=Mock(), ) with self.assertRaises(ConnectionError): d.run() @@ -142,7 +172,12 @@ def test_download_timeout(self, mock_time): with patch.object(http_client, "execute", side_effect=TimeoutError("foo")): d = downloader.ResultSetDownloadHandler( - settings, result_link, ssl_options=SSLOptions() + settings, + result_link, + ssl_options=SSLOptions(), + chunk_id=0, + session_id_hex=Mock(), + statement_id=Mock(), ) with self.assertRaises(TimeoutError): d.run() diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index a649941e1..7a0706838 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -43,7 +43,7 @@ def make_dummy_result_set_from_initial_results(initial_results): # Create a mock backend that will return the queue when _fill_results_buffer is called mock_thrift_backend = Mock(spec=ThriftDatabricksClient) - mock_thrift_backend.fetch_results.return_value = (arrow_queue, False) + mock_thrift_backend.fetch_results.return_value = (arrow_queue, False, 0) num_cols = len(initial_results[0]) if initial_results else 0 description = [ @@ -79,12 +79,13 @@ def fetch_results( arrow_schema_bytes, description, use_cloud_fetch=True, + chunk_id=0, ): nonlocal batch_index results = FetchTests.make_arrow_queue(batch_list[batch_index]) batch_index += 1 - return results, batch_index < len(batch_list) + return results, batch_index < len(batch_list), 0 mock_thrift_backend = Mock(spec=ThriftDatabricksClient) mock_thrift_backend.fetch_results = fetch_results diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index e4a9e5cdd..1d485ea61 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -36,11 +36,10 @@ def 
make_dummy_result_set_from_initial_results(arrow_table): execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, - is_direct_results=False, + has_more_rows=False, description=Mock(), command_id=None, - arrow_queue=arrow_queue, - arrow_schema=arrow_table.schema, + arrow_schema_bytes=arrow_table.schema, ), ) rs.description = [ diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index 975376e13..13dfac006 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -77,7 +77,7 @@ def test_filter_by_column_values(self): "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True ): with patch( - "databricks.sql.result_set.SeaResultSet" + "databricks.sql.backend.sea.result_set.SeaResultSet" ) as mock_sea_result_set_class: mock_instance = MagicMock() mock_sea_result_set_class.return_value = mock_instance @@ -104,7 +104,7 @@ def test_filter_by_column_values(self): "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True ): with patch( - "databricks.sql.result_set.SeaResultSet" + "databricks.sql.backend.sea.result_set.SeaResultSet" ) as mock_sea_result_set_class: mock_instance = MagicMock() mock_sea_result_set_class.return_value = mock_instance diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 6d839162e..482ce655f 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -737,7 +737,7 @@ def test_get_schemas(self, sea_client, sea_session_id, mock_cursor): def test_get_tables(self, sea_client, sea_session_id, mock_cursor): """Test the get_tables method with various parameter combinations.""" # Mock the execute_command method - from databricks.sql.result_set import SeaResultSet + from databricks.sql.backend.sea.result_set import SeaResultSet mock_result_set = Mock(spec=SeaResultSet) diff --git a/tests/unit/test_sea_conversion.py b/tests/unit/test_sea_conversion.py new file mode 100644 index 000000000..13970c5db --- /dev/null +++ b/tests/unit/test_sea_conversion.py @@ -0,0 +1,130 @@ +""" +Tests for the conversion module in the SEA backend. + +This module contains tests for the SqlType and SqlTypeConverter classes. 
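# Simplified stand-in for the conversion behaviour exercised by the tests that follow:
# SEA returns every value as a string, and the converter maps it to a Python type
# based on the column's SQL type, falling back to the raw string on failure. The real
# implementation is databricks.sql.backend.sea.utils.conversion.SqlTypeConverter; this
# toy version only handles a few representative types.
import datetime
import decimal
from typing import Any

def convert_value(value: str, sql_type: str) -> Any:
    try:
        if sql_type in ("byte", "short", "int", "long"):
            return int(value)
        if sql_type in ("float", "double"):
            return float(value)
        if sql_type == "decimal":
            return decimal.Decimal(value)
        if sql_type == "boolean":
            return value.strip().lower() in ("true", "t", "1", "yes", "y")
        if sql_type == "date":
            return datetime.date.fromisoformat(value)
        if sql_type == "binary":
            return bytes.fromhex(value)
    except (ValueError, decimal.InvalidOperation):
        return value  # mirror the "return original value on error" behaviour
    return value      # strings and complex types pass through unchanged

print(convert_value("123", "int"))            # 123
print(convert_value("48656C6C6F", "binary"))  # b'Hello'
print(convert_value("not_a_number", "int"))   # 'not_a_number'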
+""" + +import pytest +import datetime +import decimal +from unittest.mock import Mock, patch + +from databricks.sql.backend.sea.utils.conversion import SqlType, SqlTypeConverter + + +class TestSqlTypeConverter: + """Test suite for the SqlTypeConverter class.""" + + def test_convert_numeric_types(self): + """Test converting numeric types.""" + # Test integer types + assert SqlTypeConverter.convert_value("123", SqlType.BYTE) == 123 + assert SqlTypeConverter.convert_value("456", SqlType.SHORT) == 456 + assert SqlTypeConverter.convert_value("789", SqlType.INT) == 789 + assert SqlTypeConverter.convert_value("1234567890", SqlType.LONG) == 1234567890 + + # Test floating point types + assert SqlTypeConverter.convert_value("123.45", SqlType.FLOAT) == 123.45 + assert SqlTypeConverter.convert_value("678.90", SqlType.DOUBLE) == 678.90 + + # Test decimal type + decimal_value = SqlTypeConverter.convert_value("123.45", SqlType.DECIMAL) + assert isinstance(decimal_value, decimal.Decimal) + assert decimal_value == decimal.Decimal("123.45") + + # Test decimal with precision and scale + decimal_value = SqlTypeConverter.convert_value( + "123.45", SqlType.DECIMAL, precision=5, scale=2 + ) + assert isinstance(decimal_value, decimal.Decimal) + assert decimal_value == decimal.Decimal("123.45") + + # Test invalid numeric input + result = SqlTypeConverter.convert_value("not_a_number", SqlType.INT) + assert result == "not_a_number" # Returns original value on error + + def test_convert_boolean_type(self): + """Test converting boolean types.""" + # True values + assert SqlTypeConverter.convert_value("true", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("True", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("t", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("1", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("yes", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("y", SqlType.BOOLEAN) is True + + # False values + assert SqlTypeConverter.convert_value("false", SqlType.BOOLEAN) is False + assert SqlTypeConverter.convert_value("False", SqlType.BOOLEAN) is False + assert SqlTypeConverter.convert_value("f", SqlType.BOOLEAN) is False + assert SqlTypeConverter.convert_value("0", SqlType.BOOLEAN) is False + assert SqlTypeConverter.convert_value("no", SqlType.BOOLEAN) is False + assert SqlTypeConverter.convert_value("n", SqlType.BOOLEAN) is False + + def test_convert_datetime_types(self): + """Test converting datetime types.""" + # Test date type + date_value = SqlTypeConverter.convert_value("2023-01-15", SqlType.DATE) + assert isinstance(date_value, datetime.date) + assert date_value == datetime.date(2023, 1, 15) + + # Test timestamp type + timestamp_value = SqlTypeConverter.convert_value( + "2023-01-15T12:30:45", SqlType.TIMESTAMP + ) + assert isinstance(timestamp_value, datetime.datetime) + assert timestamp_value.year == 2023 + assert timestamp_value.month == 1 + assert timestamp_value.day == 15 + assert timestamp_value.hour == 12 + assert timestamp_value.minute == 30 + assert timestamp_value.second == 45 + + # Test interval type (currently returns as string) + interval_value = SqlTypeConverter.convert_value( + "1 day 2 hours", SqlType.INTERVAL + ) + assert interval_value == "1 day 2 hours" + + # Test invalid date input + result = SqlTypeConverter.convert_value("not_a_date", SqlType.DATE) + assert result == "not_a_date" # Returns original value on error + + def test_convert_string_types(self): + """Test converting string types.""" + # 
String types don't need conversion, they should be returned as-is + assert ( + SqlTypeConverter.convert_value("test string", SqlType.STRING) + == "test string" + ) + assert SqlTypeConverter.convert_value("test char", SqlType.CHAR) == "test char" + + def test_convert_binary_type(self): + """Test converting binary type.""" + # Test valid hex string + binary_value = SqlTypeConverter.convert_value("48656C6C6F", SqlType.BINARY) + assert isinstance(binary_value, bytes) + assert binary_value == b"Hello" + + # Test invalid binary input + result = SqlTypeConverter.convert_value("not_hex", SqlType.BINARY) + assert result == "not_hex" # Returns original value on error + + def test_convert_unsupported_type(self): + """Test converting an unsupported type.""" + # Should return the original value + assert SqlTypeConverter.convert_value("test", "unsupported_type") == "test" + + # Complex types should return as-is + assert ( + SqlTypeConverter.convert_value("complex_value", SqlType.ARRAY) + == "complex_value" + ) + assert ( + SqlTypeConverter.convert_value("complex_value", SqlType.MAP) + == "complex_value" + ) + assert ( + SqlTypeConverter.convert_value("complex_value", SqlType.STRUCT) + == "complex_value" + ) diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py new file mode 100644 index 000000000..cbeae098b --- /dev/null +++ b/tests/unit/test_sea_queue.py @@ -0,0 +1,720 @@ +""" +Tests for SEA-related queue classes. + +This module contains tests for the JsonQueue, SeaResultSetQueueFactory, and SeaCloudFetchQueue classes. +It also tests the Hybrid disposition which can create either ArrowQueue or SeaCloudFetchQueue based on +whether attachment is set. +""" + +import pytest +from unittest.mock import Mock, patch + +from databricks.sql.backend.sea.queue import ( + JsonQueue, + LinkFetcher, + SeaResultSetQueueFactory, + SeaCloudFetchQueue, +) +from databricks.sql.backend.sea.models.base import ( + ResultData, + ResultManifest, + ExternalLink, +) +from databricks.sql.backend.sea.utils.constants import ResultFormat +from databricks.sql.exc import ProgrammingError, ServerOperationError +from databricks.sql.types import SSLOptions +from databricks.sql.utils import ArrowQueue +import threading +import time + + +class TestJsonQueue: + """Test suite for the JsonQueue class.""" + + @pytest.fixture + def sample_data(self): + """Create sample data for testing.""" + return [ + ["value1", 1, True], + ["value2", 2, False], + ["value3", 3, True], + ["value4", 4, False], + ["value5", 5, True], + ] + + def test_init(self, sample_data): + """Test initialization of JsonQueue.""" + queue = JsonQueue(sample_data) + assert queue.data_array == sample_data + assert queue.cur_row_index == 0 + assert queue.num_rows == len(sample_data) + + def test_init_with_none(self): + """Test initialization with None data.""" + queue = JsonQueue(None) + assert queue.data_array == [] + assert queue.cur_row_index == 0 + assert queue.num_rows == 0 + + def test_next_n_rows_partial(self, sample_data): + """Test fetching a subset of rows.""" + queue = JsonQueue(sample_data) + result = queue.next_n_rows(2) + assert result == sample_data[:2] + assert queue.cur_row_index == 2 + + def test_next_n_rows_all(self, sample_data): + """Test fetching all rows.""" + queue = JsonQueue(sample_data) + result = queue.next_n_rows(len(sample_data)) + assert result == sample_data + assert queue.cur_row_index == len(sample_data) + + def test_next_n_rows_more_than_available(self, sample_data): + """Test fetching more rows than available.""" + queue = 
JsonQueue(sample_data) + result = queue.next_n_rows(len(sample_data) + 10) + assert result == sample_data + assert queue.cur_row_index == len(sample_data) + + def test_next_n_rows_zero(self, sample_data): + """Test fetching zero rows.""" + queue = JsonQueue(sample_data) + result = queue.next_n_rows(0) + assert result == [] + assert queue.cur_row_index == 0 + + def test_remaining_rows(self, sample_data): + """Test fetching all remaining rows.""" + queue = JsonQueue(sample_data) + + # Fetch some rows first + queue.next_n_rows(2) + + # Now fetch remaining + result = queue.remaining_rows() + assert result == sample_data[2:] + assert queue.cur_row_index == len(sample_data) + + def test_remaining_rows_all(self, sample_data): + """Test fetching all remaining rows from the start.""" + queue = JsonQueue(sample_data) + result = queue.remaining_rows() + assert result == sample_data + assert queue.cur_row_index == len(sample_data) + + def test_remaining_rows_empty(self, sample_data): + """Test fetching remaining rows when none are left.""" + queue = JsonQueue(sample_data) + + # Fetch all rows first + queue.next_n_rows(len(sample_data)) + + # Now fetch remaining (should be empty) + result = queue.remaining_rows() + assert result == [] + assert queue.cur_row_index == len(sample_data) + + +class TestSeaResultSetQueueFactory: + """Test suite for the SeaResultSetQueueFactory class.""" + + @pytest.fixture + def json_manifest(self): + """Create a JSON manifest for testing.""" + return ResultManifest( + format=ResultFormat.JSON_ARRAY.value, + schema={}, + total_row_count=5, + total_byte_count=1000, + total_chunk_count=1, + ) + + @pytest.fixture + def arrow_manifest(self): + """Create an Arrow manifest for testing.""" + return ResultManifest( + format=ResultFormat.ARROW_STREAM.value, + schema={}, + total_row_count=5, + total_byte_count=1000, + total_chunk_count=1, + ) + + @pytest.fixture + def invalid_manifest(self): + """Create an invalid manifest for testing.""" + return ResultManifest( + format="INVALID_FORMAT", + schema={}, + total_row_count=5, + total_byte_count=1000, + total_chunk_count=1, + ) + + @pytest.fixture + def sample_data(self): + """Create sample result data.""" + return [ + ["value1", "1", "true"], + ["value2", "2", "false"], + ] + + @pytest.fixture + def ssl_options(self): + """Create SSL options for testing.""" + return SSLOptions(tls_verify=True) + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + client = Mock() + client.max_download_threads = 10 + return client + + @pytest.fixture + def description(self): + """Create column descriptions.""" + return [ + ("col1", "string", None, None, None, None, None), + ("col2", "int", None, None, None, None, None), + ("col3", "boolean", None, None, None, None, None), + ] + + def test_build_queue_json_array(self, json_manifest, sample_data): + """Test building a JSON array queue.""" + result_data = ResultData(data=sample_data) + + queue = SeaResultSetQueueFactory.build_queue( + result_data=result_data, + manifest=json_manifest, + statement_id="test-statement", + ssl_options=SSLOptions(), + description=[], + max_download_threads=10, + sea_client=Mock(), + lz4_compressed=False, + ) + + assert isinstance(queue, JsonQueue) + assert queue.data_array == sample_data + + def test_build_queue_arrow_stream( + self, arrow_manifest, ssl_options, mock_sea_client, description + ): + """Test building an Arrow stream queue.""" + external_links = [ + ExternalLink( + external_link="https://example.com/data/chunk0", + 
expiration="2025-07-03T05:51:18.118009", + row_count=100, + byte_count=1024, + row_offset=0, + chunk_index=0, + next_chunk_index=1, + http_headers={"Authorization": "Bearer token123"}, + ) + ] + result_data = ResultData(data=None, external_links=external_links) + + with patch( + "databricks.sql.backend.sea.queue.ResultFileDownloadManager" + ), patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None): + queue = SeaResultSetQueueFactory.build_queue( + result_data=result_data, + manifest=arrow_manifest, + statement_id="test-statement", + ssl_options=ssl_options, + description=description, + max_download_threads=10, + sea_client=mock_sea_client, + lz4_compressed=False, + ) + + assert isinstance(queue, SeaCloudFetchQueue) + + def test_build_queue_invalid_format(self, invalid_manifest): + """Test building a queue with invalid format.""" + result_data = ResultData(data=[]) + + with pytest.raises(ProgrammingError, match="Invalid result format"): + SeaResultSetQueueFactory.build_queue( + result_data=result_data, + manifest=invalid_manifest, + statement_id="test-statement", + ssl_options=SSLOptions(), + description=[], + max_download_threads=10, + sea_client=Mock(), + lz4_compressed=False, + ) + + +class TestSeaCloudFetchQueue: + """Test suite for the SeaCloudFetchQueue class.""" + + @pytest.fixture + def ssl_options(self): + """Create SSL options for testing.""" + return SSLOptions(tls_verify=True) + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + client = Mock() + client.max_download_threads = 10 + return client + + @pytest.fixture + def description(self): + """Create column descriptions.""" + return [ + ("col1", "string", None, None, None, None, None), + ("col2", "int", None, None, None, None, None), + ("col3", "boolean", None, None, None, None, None), + ] + + @pytest.fixture + def sample_external_link(self): + """Create a sample external link.""" + return ExternalLink( + external_link="https://example.com/data/chunk0", + expiration="2025-07-03T05:51:18.118009", + row_count=100, + byte_count=1024, + row_offset=0, + chunk_index=0, + next_chunk_index=1, + http_headers={"Authorization": "Bearer token123"}, + ) + + @pytest.fixture + def sample_external_link_no_headers(self): + """Create a sample external link without headers.""" + return ExternalLink( + external_link="https://example.com/data/chunk0", + expiration="2025-07-03T05:51:18.118009", + row_count=100, + byte_count=1024, + row_offset=0, + chunk_index=0, + next_chunk_index=1, + http_headers=None, + ) + + def test_convert_to_thrift_link(self, sample_external_link): + """Test conversion of ExternalLink to TSparkArrowResultLink.""" + # Call the method directly + result = LinkFetcher._convert_to_thrift_link(sample_external_link) + + # Verify the conversion + assert result.fileLink == sample_external_link.external_link + assert result.rowCount == sample_external_link.row_count + assert result.bytesNum == sample_external_link.byte_count + assert result.startRowOffset == sample_external_link.row_offset + assert result.httpHeaders == sample_external_link.http_headers + + def test_convert_to_thrift_link_no_headers(self, sample_external_link_no_headers): + """Test conversion of ExternalLink with no headers to TSparkArrowResultLink.""" + # Call the method directly + result = LinkFetcher._convert_to_thrift_link(sample_external_link_no_headers) + + # Verify the conversion + assert result.fileLink == sample_external_link_no_headers.external_link + assert result.rowCount == 
sample_external_link_no_headers.row_count + assert result.bytesNum == sample_external_link_no_headers.byte_count + assert result.startRowOffset == sample_external_link_no_headers.row_offset + assert result.httpHeaders == {} + + @patch("databricks.sql.backend.sea.queue.ResultFileDownloadManager") + @patch("databricks.sql.backend.sea.queue.logger") + def test_init_with_valid_initial_link( + self, + mock_logger, + mock_download_manager_class, + mock_sea_client, + ssl_options, + description, + sample_external_link, + ): + """Test initialization with valid initial link.""" + # Create a queue with valid initial link + with patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None): + queue = SeaCloudFetchQueue( + result_data=ResultData(external_links=[sample_external_link]), + max_download_threads=5, + ssl_options=ssl_options, + sea_client=mock_sea_client, + statement_id="test-statement-123", + total_chunk_count=1, + lz4_compressed=False, + description=description, + ) + + # Verify attributes + assert queue._current_chunk_index == 0 + assert queue.link_fetcher is not None + + @patch("databricks.sql.backend.sea.queue.ResultFileDownloadManager") + @patch("databricks.sql.backend.sea.queue.logger") + def test_init_no_initial_links( + self, + mock_logger, + mock_download_manager_class, + mock_sea_client, + ssl_options, + description, + ): + """Test initialization with no initial links.""" + # Create a queue with empty initial links + queue = SeaCloudFetchQueue( + result_data=ResultData(external_links=[]), + max_download_threads=5, + ssl_options=ssl_options, + sea_client=mock_sea_client, + statement_id="test-statement-123", + total_chunk_count=0, + lz4_compressed=False, + description=description, + ) + assert queue.table is None + + @patch("databricks.sql.backend.sea.queue.logger") + def test_create_next_table_success(self, mock_logger): + """Test _create_next_table with successful table creation.""" + # Create a queue instance without initializing + queue = Mock(spec=SeaCloudFetchQueue) + queue._current_chunk_index = 0 + queue.download_manager = Mock() + queue.link_fetcher = Mock() + + # Mock the dependencies + mock_table = Mock() + mock_chunk_link = Mock() + queue.link_fetcher.get_chunk_link = Mock(return_value=mock_chunk_link) + queue._create_table_at_offset = Mock(return_value=mock_table) + + # Call the method directly + SeaCloudFetchQueue._create_next_table(queue) + + # Verify the chunk index was incremented + assert queue._current_chunk_index == 1 + + # Verify the chunk link was retrieved + queue.link_fetcher.get_chunk_link.assert_called_once_with(0) + + # Verify the table was created from the link + queue._create_table_at_offset.assert_called_once_with( + mock_chunk_link.row_offset + ) + + +class TestHybridDisposition: + """Test suite for the Hybrid disposition handling in SeaResultSetQueueFactory.""" + + @pytest.fixture + def arrow_manifest(self): + """Create an Arrow manifest for testing.""" + return ResultManifest( + format=ResultFormat.ARROW_STREAM.value, + schema={}, + total_row_count=5, + total_byte_count=1000, + total_chunk_count=1, + ) + + @pytest.fixture + def description(self): + """Create column descriptions.""" + return [ + ("col1", "string", None, None, None, None, None), + ("col2", "int", None, None, None, None, None), + ("col3", "boolean", None, None, None, None, None), + ] + + @pytest.fixture + def ssl_options(self): + """Create SSL options for testing.""" + return SSLOptions(tls_verify=True) + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA 
client.""" + client = Mock() + client.max_download_threads = 10 + return client + + @patch("databricks.sql.backend.sea.queue.create_arrow_table_from_arrow_file") + def test_hybrid_disposition_with_attachment( + self, + mock_create_table, + arrow_manifest, + description, + ssl_options, + mock_sea_client, + ): + """Test that ArrowQueue is created when attachment is present.""" + # Create mock arrow table + mock_arrow_table = Mock() + mock_arrow_table.num_rows = 5 + mock_create_table.return_value = mock_arrow_table + + # Create result data with attachment + attachment_data = b"mock_arrow_data" + result_data = ResultData(attachment=attachment_data) + + # Build queue + queue = SeaResultSetQueueFactory.build_queue( + result_data=result_data, + manifest=arrow_manifest, + statement_id="test-statement", + ssl_options=ssl_options, + description=description, + max_download_threads=10, + sea_client=mock_sea_client, + lz4_compressed=False, + ) + + # Verify ArrowQueue was created + assert isinstance(queue, ArrowQueue) + mock_create_table.assert_called_once_with(attachment_data, description) + + @patch("databricks.sql.backend.sea.queue.ResultFileDownloadManager") + @patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None) + def test_hybrid_disposition_with_external_links( + self, + mock_create_table, + mock_download_manager, + arrow_manifest, + description, + ssl_options, + mock_sea_client, + ): + """Test that SeaCloudFetchQueue is created when attachment is None but external links are present.""" + # Create external links + external_links = [ + ExternalLink( + external_link="https://example.com/data/chunk0", + expiration="2025-07-03T05:51:18.118009", + row_count=100, + byte_count=1024, + row_offset=0, + chunk_index=0, + next_chunk_index=1, + http_headers={"Authorization": "Bearer token123"}, + ) + ] + + # Create result data with external links but no attachment + result_data = ResultData(external_links=external_links, attachment=None) + + # Build queue + queue = SeaResultSetQueueFactory.build_queue( + result_data=result_data, + manifest=arrow_manifest, + statement_id="test-statement", + ssl_options=ssl_options, + description=description, + max_download_threads=10, + sea_client=mock_sea_client, + lz4_compressed=False, + ) + + # Verify SeaCloudFetchQueue was created + assert isinstance(queue, SeaCloudFetchQueue) + mock_create_table.assert_called_once() + + @patch("databricks.sql.backend.sea.queue.ResultSetDownloadHandler._decompress_data") + @patch("databricks.sql.backend.sea.queue.create_arrow_table_from_arrow_file") + def test_hybrid_disposition_with_compressed_attachment( + self, + mock_create_table, + mock_decompress, + arrow_manifest, + description, + ssl_options, + mock_sea_client, + ): + """Test that ArrowQueue is created with decompressed data when attachment is present and lz4_compressed is True.""" + # Create mock arrow table + mock_arrow_table = Mock() + mock_arrow_table.num_rows = 5 + mock_create_table.return_value = mock_arrow_table + + # Setup decompression mock + compressed_data = b"compressed_data" + decompressed_data = b"decompressed_data" + mock_decompress.return_value = decompressed_data + + # Create result data with attachment + result_data = ResultData(attachment=compressed_data) + + # Build queue with lz4_compressed=True + queue = SeaResultSetQueueFactory.build_queue( + result_data=result_data, + manifest=arrow_manifest, + statement_id="test-statement", + ssl_options=ssl_options, + description=description, + max_download_threads=10, + sea_client=mock_sea_client, + 
lz4_compressed=True, + ) + + # Verify ArrowQueue was created with decompressed data + assert isinstance(queue, ArrowQueue) + mock_decompress.assert_called_once_with(compressed_data) + mock_create_table.assert_called_once_with(decompressed_data, description) + + +class TestLinkFetcher: + """Unit tests for the LinkFetcher helper class.""" + + @pytest.fixture + def sample_links(self): + """Provide a pair of ExternalLink objects forming two sequential chunks.""" + link0 = ExternalLink( + external_link="https://example.com/data/chunk0", + expiration="2030-01-01T00:00:00.000000", + row_count=100, + byte_count=1024, + row_offset=0, + chunk_index=0, + next_chunk_index=1, + http_headers={"Authorization": "Bearer token0"}, + ) + + link1 = ExternalLink( + external_link="https://example.com/data/chunk1", + expiration="2030-01-01T00:00:00.000000", + row_count=100, + byte_count=1024, + row_offset=100, + chunk_index=1, + next_chunk_index=None, + http_headers={"Authorization": "Bearer token1"}, + ) + + return link0, link1 + + def _create_fetcher( + self, + initial_links, + backend_mock=None, + download_manager_mock=None, + total_chunk_count=10, + ): + """Helper to create a LinkFetcher instance with supplied mocks.""" + if backend_mock is None: + backend_mock = Mock() + if download_manager_mock is None: + download_manager_mock = Mock() + + return ( + LinkFetcher( + download_manager=download_manager_mock, + backend=backend_mock, + statement_id="statement-123", + initial_links=list(initial_links), + total_chunk_count=total_chunk_count, + ), + backend_mock, + download_manager_mock, + ) + + def test_add_links_and_get_next_chunk_index(self, sample_links): + """Verify that initial links are stored and next chunk index is computed correctly.""" + link0, link1 = sample_links + + fetcher, _backend, download_manager = self._create_fetcher([link0]) + + # add_link should have been called for the initial link + download_manager.add_link.assert_called_once() + + # Internal mapping should contain the link + assert fetcher.chunk_index_to_link[0] == link0 + + # The next chunk index should be 1 (from link0.next_chunk_index) + assert fetcher._get_next_chunk_index() == 1 + + # Add second link and validate it is present + fetcher._add_links([link1]) + assert fetcher.chunk_index_to_link[1] == link1 + + def test_trigger_next_batch_download_success(self, sample_links): + """Check that _trigger_next_batch_download fetches and stores new links.""" + link0, link1 = sample_links + + backend_mock = Mock() + backend_mock.get_chunk_links = Mock(return_value=[link1]) + + fetcher, backend, download_manager = self._create_fetcher( + [link0], backend_mock=backend_mock + ) + + # Trigger download of the next chunk (index 1) + success = fetcher._trigger_next_batch_download() + + assert success is True + backend.get_chunk_links.assert_called_once_with("statement-123", 1) + assert fetcher.chunk_index_to_link[1] == link1 + # Two calls to add_link: one for initial link, one for new link + assert download_manager.add_link.call_count == 2 + + def test_trigger_next_batch_download_error(self, sample_links): + """Ensure that errors from backend are captured and surfaced.""" + link0, _link1 = sample_links + + backend_mock = Mock() + backend_mock.get_chunk_links.side_effect = ServerOperationError( + "Backend failure" + ) + + fetcher, backend, download_manager = self._create_fetcher( + [link0], backend_mock=backend_mock + ) + + success = fetcher._trigger_next_batch_download() + + assert success is False + assert fetcher._error is not None + + def 
test_get_chunk_link_waits_until_available(self, sample_links): + """Validate that get_chunk_link blocks until the requested link is available and then returns it.""" + link0, link1 = sample_links + + backend_mock = Mock() + # Configure backend to return link1 when requested for chunk index 1 + backend_mock.get_chunk_links = Mock(return_value=[link1]) + + fetcher, backend, download_manager = self._create_fetcher( + [link0], backend_mock=backend_mock, total_chunk_count=2 + ) + + # Holder to capture the link returned from the background thread + result_container = {} + + def _worker(): + result_container["link"] = fetcher.get_chunk_link(1) + + thread = threading.Thread(target=_worker) + thread.start() + + # Give the thread a brief moment to start and attempt to fetch (and therefore block) + time.sleep(0.1) + + # Trigger the backend fetch which will add link1 and notify waiting threads + fetcher._trigger_next_batch_download() + + thread.join(timeout=2) + + # The thread should have finished and captured link1 + assert result_container.get("link") == link1 + + def test_get_chunk_link_out_of_range_returns_none(self, sample_links): + """Requesting a chunk index >= total_chunk_count should immediately return None.""" + link0, _ = sample_links + + fetcher, _backend, _dm = self._create_fetcher([link0], total_chunk_count=1) + + assert fetcher.get_chunk_link(10) is None diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index c596dbc14..c42e66659 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -6,10 +6,18 @@ """ import pytest -from unittest.mock import patch, MagicMock, Mock +from unittest.mock import Mock, patch -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType +try: + import pyarrow +except ImportError: + pyarrow = None + +from databricks.sql.backend.sea.result_set import SeaResultSet, Row +from databricks.sql.backend.sea.queue import JsonQueue +from databricks.sql.backend.sea.utils.constants import ResultFormat +from databricks.sql.backend.types import CommandId, CommandState +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest class TestSeaResultSet: @@ -20,12 +28,16 @@ def mock_connection(self): """Create a mock connection.""" connection = Mock() connection.open = True + connection.session = Mock() + connection.session.ssl_options = Mock() return connection @pytest.fixture def mock_sea_client(self): """Create a mock SEA client.""" - return Mock() + client = Mock() + client.max_download_threads = 10 + return client @pytest.fixture def execute_response(self): @@ -34,25 +46,163 @@ def execute_response(self): mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") mock_response.status = CommandState.SUCCEEDED mock_response.has_been_closed_server_side = False - mock_response.is_direct_results = False + mock_response.has_more_rows = False mock_response.results_queue = None mock_response.description = [ - ("test_value", "INT", None, None, None, None, None) + ("col1", "string", None, None, None, None, None), + ("col2", "int", None, None, None, None, None), + ("col3", "boolean", None, None, None, None, None), ] mock_response.is_staging_operation = False + mock_response.lz4_compressed = False + mock_response.arrow_schema_bytes = None return mock_response + @pytest.fixture + def sample_data(self): + """Create sample data for testing.""" + return [ + ["value1", "1", "true"], + ["value2", "2", "false"], + 
["value3", "3", "true"], + ["value4", "4", "false"], + ["value5", "5", "true"], + ] + + def _create_empty_manifest(self, format: ResultFormat): + """Create an empty manifest.""" + return ResultManifest( + format=format.value, + schema={}, + total_row_count=-1, + total_byte_count=-1, + total_chunk_count=-1, + ) + + @pytest.fixture + def result_set_with_data( + self, mock_connection, mock_sea_client, execute_response, sample_data + ): + """Create a SeaResultSet with sample data.""" + # Create ResultData with inline data + result_data = ResultData( + data=sample_data, external_links=None, row_count=len(sample_data) + ) + + # Initialize SeaResultSet with result data + with patch( + "databricks.sql.backend.sea.queue.SeaResultSetQueueFactory.build_queue", + return_value=JsonQueue(sample_data), + ): + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + result_data=result_data, + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), + buffer_size_bytes=1000, + arraysize=100, + ) + + return result_set + + @pytest.fixture + def mock_arrow_queue(self): + """Create a mock Arrow queue.""" + queue = Mock() + if pyarrow is not None: + queue.next_n_rows.return_value = Mock(spec=pyarrow.Table) + queue.next_n_rows.return_value.num_rows = 0 + queue.remaining_rows.return_value = Mock(spec=pyarrow.Table) + queue.remaining_rows.return_value.num_rows = 0 + return queue + + @pytest.fixture + def mock_json_queue(self): + """Create a mock JSON queue.""" + queue = Mock(spec=JsonQueue) + queue.next_n_rows.return_value = [] + queue.remaining_rows.return_value = [] + return queue + + @pytest.fixture + def result_set_with_arrow_queue( + self, mock_connection, mock_sea_client, execute_response, mock_arrow_queue + ): + """Create a SeaResultSet with an Arrow queue.""" + # Create ResultData with external links + result_data = ResultData(data=None, external_links=[], row_count=0) + + # Initialize SeaResultSet with result data + with patch( + "databricks.sql.backend.sea.queue.SeaResultSetQueueFactory.build_queue", + return_value=mock_arrow_queue, + ): + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + result_data=result_data, + manifest=ResultManifest( + format=ResultFormat.ARROW_STREAM.value, + schema={}, + total_row_count=0, + total_byte_count=0, + total_chunk_count=0, + ), + buffer_size_bytes=1000, + arraysize=100, + ) + + return result_set + + @pytest.fixture + def result_set_with_json_queue( + self, mock_connection, mock_sea_client, execute_response, mock_json_queue + ): + """Create a SeaResultSet with a JSON queue.""" + # Create ResultData with inline data + result_data = ResultData(data=[], external_links=None, row_count=0) + + # Initialize SeaResultSet with result data + with patch( + "databricks.sql.backend.sea.queue.SeaResultSetQueueFactory.build_queue", + return_value=mock_json_queue, + ): + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + result_data=result_data, + manifest=ResultManifest( + format=ResultFormat.JSON_ARRAY.value, + schema={}, + total_row_count=0, + total_byte_count=0, + total_chunk_count=0, + ), + buffer_size_bytes=1000, + arraysize=100, + ) + + return result_set + def test_init_with_execute_response( self, mock_connection, mock_sea_client, execute_response ): """Test initializing SeaResultSet with an execute response.""" - result_set = SeaResultSet( - 
connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) + with patch( + "databricks.sql.backend.sea.queue.SeaResultSetQueueFactory.build_queue" + ): + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + result_data=ResultData(data=[]), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), + buffer_size_bytes=1000, + arraysize=100, + ) # Verify basic properties assert result_set.command_id == execute_response.command_id @@ -63,15 +213,40 @@ def test_init_with_execute_response( assert result_set.arraysize == 100 assert result_set.description == execute_response.description + def test_init_with_invalid_command_id( + self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet with invalid command ID.""" + # Mock the command ID to return None + mock_command_id = Mock() + mock_command_id.to_sea_statement_id.return_value = None + execute_response.command_id = mock_command_id + + with pytest.raises(ValueError, match="Command ID is not a SEA statement ID"): + SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + result_data=ResultData(data=[]), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), + buffer_size_bytes=1000, + arraysize=100, + ) + def test_close(self, mock_connection, mock_sea_client, execute_response): """Test closing a result set.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) + with patch( + "databricks.sql.backend.sea.queue.SeaResultSetQueueFactory.build_queue" + ): + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + result_data=ResultData(data=[]), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), + buffer_size_bytes=1000, + arraysize=100, + ) # Close the result set result_set.close() @@ -85,14 +260,19 @@ def test_close_when_already_closed_server_side( self, mock_connection, mock_sea_client, execute_response ): """Test closing a result set that has already been closed server-side.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.has_been_closed_server_side = True + with patch( + "databricks.sql.backend.sea.queue.SeaResultSetQueueFactory.build_queue" + ): + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + result_data=ResultData(data=[]), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.has_been_closed_server_side = True # Close the result set result_set.close() @@ -107,13 +287,18 @@ def test_close_when_connection_closed( ): """Test closing a result set when the connection is closed.""" mock_connection.open = False - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) + with patch( + "databricks.sql.backend.sea.queue.SeaResultSetQueueFactory.build_queue" + ): + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + 
result_data=ResultData(data=[]), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), + buffer_size_bytes=1000, + arraysize=100, + ) # Close the result set result_set.close() @@ -123,79 +308,307 @@ def test_close_when_connection_closed( assert result_set.has_been_closed_server_side is True assert result_set.status == CommandState.CLOSED - def test_unimplemented_methods( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that unimplemented methods raise NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) + def test_convert_json_types(self, result_set_with_data, sample_data): + """Test the _convert_json_types method.""" + # Call _convert_json_types + converted_row = result_set_with_data._convert_json_types(sample_data[0]) - # Test each unimplemented method individually with specific error messages - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set.fetchone() + # Verify the conversion + assert converted_row[0] == "value1" # string stays as string + assert converted_row[1] == 1 # "1" converted to int + assert converted_row[2] is True # "true" converted to boolean - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - result_set.fetchmany(10) + @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") + def test_convert_json_to_arrow_table(self, result_set_with_data, sample_data): + """Test the _convert_json_to_arrow_table method.""" + # Call _convert_json_to_arrow_table + result_table = result_set_with_data._convert_json_to_arrow_table(sample_data) - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - # Test with default parameter value - result_set.fetchmany() + # Verify the result + assert isinstance(result_table, pyarrow.Table) + assert result_table.num_rows == len(sample_data) + assert result_table.num_columns == 3 - with pytest.raises( - NotImplementedError, match="fetchall is not implemented for SEA backend" - ): - result_set.fetchall() + @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") + def test_convert_json_to_arrow_table_empty(self, result_set_with_data): + """Test the _convert_json_to_arrow_table method with empty data.""" + # Call _convert_json_to_arrow_table with empty data + result_table = result_set_with_data._convert_json_to_arrow_table([]) - with pytest.raises( - NotImplementedError, - match="fetchmany_arrow is not implemented for SEA backend", - ): - result_set.fetchmany_arrow(10) + # Verify the result + assert isinstance(result_table, pyarrow.Table) + assert result_table.num_rows == 0 + + def test_create_json_table(self, result_set_with_data, sample_data): + """Test the _create_json_table method.""" + # Call _create_json_table + result_rows = result_set_with_data._create_json_table(sample_data) + + # Verify the result + assert len(result_rows) == len(sample_data) + assert isinstance(result_rows[0], Row) + assert result_rows[0].col1 == "value1" + assert result_rows[0].col2 == 1 + assert result_rows[0].col3 is True + + def test_fetchmany_json(self, result_set_with_data): + """Test the fetchmany_json method.""" + # Test fetching a subset of rows + result = result_set_with_data.fetchmany_json(2) + assert len(result) == 2 + assert result_set_with_data._next_row_index == 2 + + # Test fetching the next 
subset + result = result_set_with_data.fetchmany_json(2) + assert len(result) == 2 + assert result_set_with_data._next_row_index == 4 + + # Test fetching more than available + result = result_set_with_data.fetchmany_json(10) + assert len(result) == 1 # Only one row left + assert result_set_with_data._next_row_index == 5 + def test_fetchmany_json_negative_size(self, result_set_with_data): + """Test the fetchmany_json method with negative size.""" with pytest.raises( - NotImplementedError, - match="fetchall_arrow is not implemented for SEA backend", + ValueError, match="size argument for fetchmany is -1 but must be >= 0" ): - result_set.fetchall_arrow() + result_set_with_data.fetchmany_json(-1) + def test_fetchall_json(self, result_set_with_data, sample_data): + """Test the fetchall_json method.""" + # Test fetching all rows + result = result_set_with_data.fetchall_json() + assert result == sample_data + assert result_set_with_data._next_row_index == len(sample_data) + + # Test fetching again (should return empty) + result = result_set_with_data.fetchall_json() + assert result == [] + assert result_set_with_data._next_row_index == len(sample_data) + + @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") + def test_fetchmany_arrow(self, result_set_with_data, sample_data): + """Test the fetchmany_arrow method.""" + # Test with JSON queue (should convert to Arrow) + result = result_set_with_data.fetchmany_arrow(2) + assert isinstance(result, pyarrow.Table) + assert result.num_rows == 2 + assert result_set_with_data._next_row_index == 2 + + @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") + def test_fetchmany_arrow_negative_size(self, result_set_with_data): + """Test the fetchmany_arrow method with negative size.""" with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" + ValueError, match="size argument for fetchmany is -1 but must be >= 0" ): - # Test iteration protocol (calls fetchone internally) - next(iter(result_set)) + result_set_with_data.fetchmany_arrow(-1) + + @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") + def test_fetchall_arrow(self, result_set_with_data, sample_data): + """Test the fetchall_arrow method.""" + # Test with JSON queue (should convert to Arrow) + result = result_set_with_data.fetchall_arrow() + assert isinstance(result, pyarrow.Table) + assert result.num_rows == len(sample_data) + assert result_set_with_data._next_row_index == len(sample_data) + + def test_fetchone(self, result_set_with_data): + """Test the fetchone method.""" + # Test fetching one row at a time + row1 = result_set_with_data.fetchone() + assert isinstance(row1, Row) + assert row1.col1 == "value1" + assert row1.col2 == 1 + assert row1.col3 is True + assert result_set_with_data._next_row_index == 1 + + row2 = result_set_with_data.fetchone() + assert isinstance(row2, Row) + assert row2.col1 == "value2" + assert row2.col2 == 2 + assert row2.col3 is False + assert result_set_with_data._next_row_index == 2 + # Fetch the rest + result_set_with_data.fetchall() + + # Test fetching when no more rows + row_none = result_set_with_data.fetchone() + assert row_none is None + + def test_fetchmany(self, result_set_with_data): + """Test the fetchmany method.""" + # Test fetching multiple rows + rows = result_set_with_data.fetchmany(2) + assert len(rows) == 2 + assert isinstance(rows[0], Row) + assert rows[0].col1 == "value1" + assert rows[0].col2 == 1 + assert rows[0].col3 is True + assert rows[1].col1 == 
"value2" + assert rows[1].col2 == 2 + assert rows[1].col3 is False + assert result_set_with_data._next_row_index == 2 + + # Test with invalid size with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" + ValueError, match="size argument for fetchmany is -1 but must be >= 0" ): - # Test using the result set in a for loop - for row in result_set: - pass + result_set_with_data.fetchmany(-1) + + def test_fetchall(self, result_set_with_data, sample_data): + """Test the fetchall method.""" + # Test fetching all rows + rows = result_set_with_data.fetchall() + assert len(rows) == len(sample_data) + assert isinstance(rows[0], Row) + assert rows[0].col1 == "value1" + assert rows[0].col2 == 1 + assert rows[0].col3 is True + assert result_set_with_data._next_row_index == len(sample_data) - def test_fill_results_buffer_not_implemented( + # Test fetching again (should return empty) + rows = result_set_with_data.fetchall() + assert len(rows) == 0 + + def test_iteration(self, result_set_with_data, sample_data): + """Test iterating over the result set.""" + # Test iteration + rows = list(result_set_with_data) + assert len(rows) == len(sample_data) + assert isinstance(rows[0], Row) + assert rows[0].col1 == "value1" + assert rows[0].col2 == 1 + assert rows[0].col3 is True + + def test_is_staging_operation( self, mock_connection, mock_sea_client, execute_response ): - """Test that _fill_results_buffer raises NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) + """Test the is_staging_operation property.""" + # Set is_staging_operation to True + execute_response.is_staging_operation = True - with pytest.raises( - NotImplementedError, - match="_fill_results_buffer is not implemented for SEA backend", + with patch( + "databricks.sql.backend.sea.queue.SeaResultSetQueueFactory.build_queue" ): - result_set._fill_results_buffer() + # Create a result set + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + result_data=ResultData(data=[]), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), + buffer_size_bytes=1000, + arraysize=100, + ) + + # Test the property + assert result_set.is_staging_operation is True + + # Edge case tests + @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") + def test_fetchone_empty_arrow_queue(self, result_set_with_arrow_queue): + """Test fetchone with an empty Arrow queue.""" + # Setup _convert_arrow_table to return empty list + result_set_with_arrow_queue._convert_arrow_table = Mock(return_value=[]) + + # Call fetchone + result = result_set_with_arrow_queue.fetchone() + + # Verify result is None + assert result is None + + # Verify _convert_arrow_table was called + result_set_with_arrow_queue._convert_arrow_table.assert_called_once() + + def test_fetchone_empty_json_queue(self, result_set_with_json_queue): + """Test fetchone with an empty JSON queue.""" + # Setup _create_json_table to return empty list + result_set_with_json_queue._create_json_table = Mock(return_value=[]) + + # Call fetchone + result = result_set_with_json_queue.fetchone() + + # Verify result is None + assert result is None + + # Verify _create_json_table was called + result_set_with_json_queue._create_json_table.assert_called_once() + + @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") + def 
test_fetchmany_empty_arrow_queue(self, result_set_with_arrow_queue): + """Test fetchmany with an empty Arrow queue.""" + # Setup _convert_arrow_table to return empty list + result_set_with_arrow_queue._convert_arrow_table = Mock(return_value=[]) + + # Call fetchmany + result = result_set_with_arrow_queue.fetchmany(10) + + # Verify result is an empty list + assert result == [] + + # Verify _convert_arrow_table was called + result_set_with_arrow_queue._convert_arrow_table.assert_called_once() + + @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") + def test_fetchall_empty_arrow_queue(self, result_set_with_arrow_queue): + """Test fetchall with an empty Arrow queue.""" + # Setup _convert_arrow_table to return empty list + result_set_with_arrow_queue._convert_arrow_table = Mock(return_value=[]) + + # Call fetchall + result = result_set_with_arrow_queue.fetchall() + + # Verify result is an empty list + assert result == [] + + # Verify _convert_arrow_table was called + result_set_with_arrow_queue._convert_arrow_table.assert_called_once() + + @patch("databricks.sql.backend.sea.utils.conversion.SqlTypeConverter.convert_value") + def test_convert_json_types_with_errors( + self, mock_convert_value, result_set_with_data + ): + """Test error handling in _convert_json_types.""" + # Mock the conversion to fail for the second and third values + mock_convert_value.side_effect = [ + "value1", # First value converts normally + Exception("Invalid int"), # Second value fails + Exception("Invalid boolean"), # Third value fails + ] + + # Data with invalid values + data_row = ["value1", "not_an_int", "not_a_boolean"] + + # Should not raise an exception but log warnings + result = result_set_with_data._convert_json_types(data_row) + + # The first value should be converted normally + assert result[0] == "value1" + + # The invalid values should remain as strings + assert result[1] == "not_an_int" + assert result[2] == "not_a_boolean" + + @patch("databricks.sql.backend.sea.result_set.logger") + @patch("databricks.sql.backend.sea.utils.conversion.SqlTypeConverter.convert_value") + def test_convert_json_types_with_logging( + self, mock_convert_value, mock_logger, result_set_with_data + ): + """Test that errors in _convert_json_types are logged.""" + # Mock the conversion to fail for the second and third values + mock_convert_value.side_effect = [ + "value1", # First value converts normally + Exception("Invalid int"), # Second value fails + Exception("Invalid boolean"), # Third value fails + ] + + # Data with invalid values + data_row = ["value1", "not_an_int", "not_a_boolean"] + + # Call the method + result_set_with_data._convert_json_types(data_row) + + # Verify warnings were logged + assert mock_logger.warning.call_count == 2 diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 4e6e928ab..db765c278 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -8,6 +8,7 @@ NoopTelemetryClient, TelemetryClientFactory, TelemetryHelper, + BaseTelemetryClient, ) from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow from databricks.sql.auth.authenticators import ( @@ -290,7 +291,9 @@ def test_factory_shutdown_flow(self): assert TelemetryClientFactory._initialized is False assert TelemetryClientFactory._executor is None - @patch("databricks.sql.telemetry.telemetry_client.TelemetryClient.export_failure_log") + @patch( + "databricks.sql.telemetry.telemetry_client.TelemetryClient.export_failure_log" + ) @patch("databricks.sql.client.Session") def 
test_connection_failure_sends_correct_telemetry_payload( self, mock_session, mock_export_failure_log @@ -305,6 +308,7 @@ def test_connection_failure_sends_correct_telemetry_payload( try: from databricks import sql + sql.connect(server_hostname="test-host", http_path="/test-path") except Exception as e: assert str(e) == error_message @@ -312,4 +316,4 @@ def test_connection_failure_sends_correct_telemetry_payload( mock_export_failure_log.assert_called_once() call_arguments = mock_export_failure_log.call_args assert call_arguments[0][0] == "Exception" - assert call_arguments[0][1] == error_message \ No newline at end of file + assert call_arguments[0][1] == error_message diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 1b1a7e380..0cdb43f5c 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -611,7 +611,8 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): self.assertIn("some information about the error", str(cm.exception)) @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", + return_value=Mock(), ) def test_handle_execute_response_sets_compression_in_direct_results( self, build_queue @@ -730,7 +731,9 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): ssl_options=SSLOptions(), ) with self.assertRaises(DatabaseError) as cm: - thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) + thrift_backend.execute_command( + Mock(), Mock(), 100, 100, Mock(), Mock(), Mock() + ) self.assertEqual(display_message, str(cm.exception)) self.assertIn(diagnostic_info, str(cm.exception.message_with_context())) @@ -771,7 +774,9 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla ssl_options=SSLOptions(), ) with self.assertRaises(DatabaseError) as cm: - thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) + thrift_backend.execute_command( + Mock(), Mock(), 100, 100, Mock(), Mock(), Mock() + ) self.assertEqual(display_message, str(cm.exception)) self.assertIn(diagnostic_info, str(cm.exception.message_with_context())) @@ -1004,16 +1009,17 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): ) @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", + return_value=Mock(), ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue ): - for is_direct_results, resp_type in itertools.product( + for has_more_rows, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): + with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = Mock() results_mock.startRowOffset = 0 @@ -1025,7 +1031,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( resultSetMetadata=self.metadata_resp, resultSet=ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=is_direct_results, + hasMoreRows=has_more_rows, results=results_mock, ), closeOperation=Mock(), @@ -1046,19 +1052,20 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( has_more_rows_result, 
) = thrift_backend._handle_execute_response(execute_resp, Mock()) - self.assertEqual(is_direct_results, has_more_rows_result) + self.assertEqual(has_more_rows, has_more_rows_result) @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", + return_value=Mock(), ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue ): - for is_direct_results, resp_type in itertools.product( + for has_more_rows, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): + with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = MagicMock() results_mock.startRowOffset = 0 @@ -1071,7 +1078,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( fetch_results_resp = ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=is_direct_results, + hasMoreRows=has_more_rows, results=results_mock, resultSetMetadata=ttypes.TGetResultSetMetadataResp( resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET @@ -1094,7 +1101,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( thrift_backend = self._make_fake_thrift_backend() thrift_backend._handle_execute_response(execute_resp, Mock()) - _, has_more_rows_resp = thrift_backend.fetch_results( + _, has_more_rows_resp, _ = thrift_backend.fetch_results( command_id=Mock(), max_rows=1, max_bytes=1, @@ -1102,9 +1109,10 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( lz4_compressed=False, arrow_schema_bytes=Mock(), description=Mock(), + chunk_id=0, ) - self.assertEqual(is_direct_results, has_more_rows_resp) + self.assertEqual(has_more_rows, has_more_rows_resp) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class): @@ -1147,7 +1155,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - arrow_queue, has_more_results = thrift_backend.fetch_results( + arrow_queue, has_more_results, _ = thrift_backend.fetch_results( command_id=Mock(), max_rows=1, max_bytes=1, @@ -1155,6 +1163,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): lz4_compressed=False, arrow_schema_bytes=schema, description=MagicMock(), + chunk_id=0, ) self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) @@ -1180,7 +1189,7 @@ def test_execute_statement_calls_client_and_handle_execute_response( cursor_mock = Mock() result = thrift_backend.execute_command( - "foo", Mock(), 100, 200, Mock(), cursor_mock + "foo", Mock(), 100, 200, Mock(), cursor_mock, Mock() ) # Verify the result is a ResultSet self.assertEqual(result, mock_result_set.return_value) @@ -1445,7 +1454,9 @@ def test_non_arrow_non_column_based_set_triggers_exception( thrift_backend = self._make_fake_thrift_backend() with self.assertRaises(OperationalError) as cm: - thrift_backend.execute_command("foo", Mock(), 100, 100, Mock(), Mock()) + thrift_backend.execute_command( + "foo", Mock(), 100, 100, Mock(), Mock(), Mock() + ) self.assertIn( "Expected results to be in Arrow or column based format", str(cm.exception) ) @@ -2274,7 +2285,9 @@ def 
test_execute_command_sets_complex_type_fields_correctly( ssl_options=SSLOptions(), **complex_arg_types, ) - thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) + thrift_backend.execute_command( + Mock(), Mock(), 100, 100, Mock(), Mock(), Mock() + ) t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[ 0 ][0]