diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py
index 130f0c5bf..db430a525 100644
--- a/src/databricks/sql/backend/sea/queue.py
+++ b/src/databricks/sql/backend/sea/queue.py
@@ -364,14 +364,14 @@ def __init__(
         # Initialize table and position
         self.table = self._create_next_table()
 
-    def _create_next_table(self) -> Union["pyarrow.Table", None]:
+    def _create_next_table(self) -> "pyarrow.Table":
         """Create next table by retrieving the logical next downloaded file."""
         if self.link_fetcher is None:
-            return None
+            return self._create_empty_table()
 
         chunk_link = self.link_fetcher.get_chunk_link(self._current_chunk_index)
         if chunk_link is None:
-            return None
+            return self._create_empty_table()
 
         row_offset = chunk_link.row_offset
         # NOTE: link has already been submitted to download manager at this point
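With `_create_next_table` now always returning a `pyarrow.Table`, the zero-row table produced by `_create_empty_table` replaces `None` as the end-of-data sentinel. The helper itself is not shown in this diff; the sketch below is only an illustration of the idea, assuming the empty table is derived from the cursor `description` (the real helper may build it from `schema_bytes` instead):

```python
import pyarrow

def create_empty_table_from_description(description):
    """Hypothetical stand-in for _create_empty_table: a zero-row Arrow table
    whose columns are named after the cursor description (all typed as string
    here purely for illustration)."""
    names = [col[0] for col in (description or [])]
    schema = pyarrow.schema([(name, pyarrow.string()) for name in names])
    return pyarrow.Table.from_pydict({name: [] for name in names}, schema=schema)

# A zero-row table is still a real Table, so downstream code can test
# `table.num_rows == 0` instead of `table is None` to detect end of data.
empty = create_empty_table_from_description([("id",), ("name",)])
assert empty.num_rows == 0 and empty.num_columns == 2
```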
diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py
index 32b698bed..e187771f7 100644
--- a/src/databricks/sql/cloudfetch/download_manager.py
+++ b/src/databricks/sql/cloudfetch/download_manager.py
@@ -1,6 +1,7 @@
 import logging
 
 from concurrent.futures import ThreadPoolExecutor, Future
+import threading
 from typing import List, Union, Tuple, Optional
 
 from databricks.sql.cloudfetch.downloader import (
@@ -8,6 +9,7 @@
     DownloadableResultSettings,
     DownloadedFile,
 )
+from databricks.sql.exc import Error
 from databricks.sql.types import SSLOptions
 from databricks.sql.telemetry.models.event import StatementType
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
@@ -39,8 +41,10 @@ def __init__(
             self._pending_links.append((i, link))
         self.chunk_id += len(links)
 
-        self._download_tasks: List[Future[DownloadedFile]] = []
         self._max_download_threads: int = max_download_threads
+
+        self._download_condition = threading.Condition()
+        self._download_tasks: List[Future[DownloadedFile]] = []
         self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
 
         self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
@@ -48,17 +52,13 @@ def __init__(
         self.session_id_hex = session_id_hex
         self.statement_id = statement_id
 
-    def get_next_downloaded_file(
-        self, next_row_offset: int
-    ) -> Union[DownloadedFile, None]:
+    def get_next_downloaded_file(self, next_row_offset: int) -> DownloadedFile:
         """
         Get next file that starts at given offset.
 
         This function gets the next downloaded file in which its rows start at the specified next_row_offset
         in relation to the full result. File downloads are scheduled if not already, and once the correct
         download handler is located, the function waits for the download status and returns the resulting file.
-        If there are no more downloads, a download was not successful, or the correct file could not be located,
-        this function shuts down the thread pool and returns None.
 
         Args:
             next_row_offset (int): The offset of the starting row of the next file we want data from.
@@ -67,10 +67,11 @@ def get_next_downloaded_file(
         # Make sure the download queue is always full
         self._schedule_downloads()
 
-        # No more files to download from this batch of links
-        if len(self._download_tasks) == 0:
-            self._shutdown_manager()
-            return None
+        while len(self._download_tasks) == 0:
+            if self._thread_pool._shutdown:
+                raise Error("download manager shut down before file was ready")
+            with self._download_condition:
+                self._download_condition.wait()
 
         task = self._download_tasks.pop(0)
         # Future's `result()` method will wait for the call to complete, and return
@@ -113,6 +114,9 @@ def _schedule_downloads(self):
             task = self._thread_pool.submit(handler.run)
             self._download_tasks.append(task)
 
+        with self._download_condition:
+            self._download_condition.notify_all()
+
     def add_link(self, link: TSparkArrowResultLink):
         """
         Add more links to the download manager.
@@ -132,8 +136,12 @@ def add_link(self, link: TSparkArrowResultLink):
         self._pending_links.append((self.chunk_id, link))
         self.chunk_id += 1
 
+        self._schedule_downloads()
+
     def _shutdown_manager(self):
         # Clear download handlers and shutdown the thread pool
         self._pending_links = []
         self._download_tasks = []
         self._thread_pool.shutdown(wait=False)
+        with self._download_condition:
+            self._download_condition.notify_all()
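The reworked `get_next_downloaded_file` is a standard condition-variable consumer: it loops while no download task is queued, checks a shutdown flag on each pass, and blocks on `wait()`; the producers (`_schedule_downloads` after submitting work, `_shutdown_manager` on teardown) call `notify_all()` so the waiter re-checks its predicate. A stripped-down sketch of the same pattern with hypothetical names (it holds the lock for the whole check-and-wait, the textbook form, whereas the code above re-acquires it per iteration):

```python
import threading
from typing import List


class TaskQueue:
    """Minimal producer/consumer queue mirroring the download manager's wait loop."""

    def __init__(self) -> None:
        self._tasks: List[str] = []
        self._shutdown = False
        self._condition = threading.Condition()

    def add_task(self, task: str) -> None:
        # Producer: change state under the lock, then wake any waiting consumer.
        with self._condition:
            self._tasks.append(task)
            self._condition.notify_all()

    def shutdown(self) -> None:
        with self._condition:
            self._shutdown = True
            self._condition.notify_all()

    def next_task(self) -> str:
        # Consumer: wait until a task exists or the queue has been shut down.
        with self._condition:
            while not self._tasks:
                if self._shutdown:
                    raise RuntimeError("queue shut down before a task was ready")
                self._condition.wait()
            return self._tasks.pop(0)
```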
logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows)) @@ -304,12 +300,9 @@ def remaining_rows(self) -> "pyarrow.Table": pyarrow.Table """ - if not self.table: - # Return empty pyarrow table to cause retry of fetch - return self._create_empty_table() results = self.table.slice(0, 0) partial_result_chunks = [results] - while self.table: + while self.table.num_rows > 0: table_slice = self.table.slice( self.table_row_index, self.table.num_rows - self.table_row_index ) @@ -319,17 +312,11 @@ def remaining_rows(self) -> "pyarrow.Table": self.table_row_index = 0 return pyarrow.concat_tables(partial_result_chunks, use_threads=True) - def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]: + def _create_table_at_offset(self, offset: int) -> "pyarrow.Table": """Create next table at the given row offset""" # Create next table by retrieving the logical next downloaded file, or return None to signal end of queue downloaded_file = self.download_manager.get_next_downloaded_file(offset) - if not downloaded_file: - logger.debug( - "CloudFetchQueue: Cannot find downloaded file for row {}".format(offset) - ) - # None signals no more Arrow tables can be built from the remaining handlers if any remain - return None arrow_table = create_arrow_table_from_arrow_file( downloaded_file.file_bytes, self.description ) @@ -345,7 +332,7 @@ def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]: return arrow_table @abstractmethod - def _create_next_table(self) -> Union["pyarrow.Table", None]: + def _create_next_table(self) -> "pyarrow.Table": """Create next table by retrieving the logical next downloaded file.""" pass @@ -364,7 +351,7 @@ class ThriftCloudFetchQueue(CloudFetchQueue): def __init__( self, - schema_bytes, + schema_bytes: Optional[bytes], max_download_threads: int, ssl_options: SSLOptions, session_id_hex: Optional[str], @@ -398,6 +385,8 @@ def __init__( chunk_id=chunk_id, ) + self.num_links_downloaded = 0 + self.start_row_index = start_row_offset self.result_links = result_links or [] self.session_id_hex = session_id_hex @@ -421,20 +410,23 @@ def __init__( # Initialize table and position self.table = self._create_next_table() - def _create_next_table(self) -> Union["pyarrow.Table", None]: + def _create_next_table(self) -> "pyarrow.Table": + if self.num_links_downloaded >= len(self.result_links): + return self._create_empty_table() + logger.debug( "ThriftCloudFetchQueue: Trying to get downloaded file for row {}".format( self.start_row_index ) ) arrow_table = self._create_table_at_offset(self.start_row_index) - if arrow_table: - self.start_row_index += arrow_table.num_rows - logger.debug( - "ThriftCloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format( - arrow_table.num_rows, self.start_row_index - ) + self.num_links_downloaded += 1 + self.start_row_index += arrow_table.num_rows + logger.debug( + "ThriftCloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format( + arrow_table.num_rows, self.start_row_index ) + ) return arrow_table @@ -740,7 +732,6 @@ def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table": def convert_to_assigned_datatypes_in_column_table(column_table, description): - converted_column_table = [] for i, col in enumerate(column_table): if description[i][1] == "decimal": diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 4271f0d7d..29313ff02 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -105,7 +105,11 @@ 
diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index 4271f0d7d..29313ff02 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -105,7 +105,11 @@
     def test_closing_connection_closes_commands(self, mock_thrift_client_class):
         # Mock the execute response with controlled state
         mock_execute_response = Mock(spec=ExecuteResponse)
-        mock_execute_response.status = initial_state
+
+        mock_execute_response.command_id = Mock(spec=CommandId)
+        mock_execute_response.status = (
+            CommandState.SUCCEEDED if not closed else CommandState.CLOSED
+        )
         mock_execute_response.has_been_closed_server_side = closed
         mock_execute_response.is_staging_operation = False
         mock_execute_response.command_id = Mock(spec=CommandId)
@@ -262,9 +266,7 @@ def test_negative_fetch_throws_exception(self):
         mock_backend = Mock()
         mock_backend.fetch_results.return_value = (Mock(), False, 0)
 
-        result_set = ThriftResultSet(
-            Mock(), Mock(), mock_backend
-        )
+        result_set = ThriftResultSet(Mock(), Mock(), mock_backend)
 
         with self.assertRaises(ValueError) as e:
             result_set.fetchmany(-1)
diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py
index faa8e2f99..31450e7fd 100644
--- a/tests/unit/test_cloud_fetch_queue.py
+++ b/tests/unit/test_cloud_fetch_queue.py
@@ -36,27 +36,30 @@ def create_result_links(self, num_files: int, start_row_offset: int = 0):
         return result_links
 
     @staticmethod
-    def make_arrow_table():
-        batch = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
-        n_cols = len(batch[0]) if batch else 0
+    def make_arrow_table(num_rows: int = 4, num_cols: int = 4):
+        batch = [[i for i in range(num_cols)] for _ in range(num_rows)]
+        n_cols = len(batch[0]) if batch else num_cols
         schema = pyarrow.schema({"col%s" % i: pyarrow.uint32() for i in range(n_cols)})
         cols = [[batch[row][col] for row in range(len(batch))] for col in range(n_cols)]
         return pyarrow.Table.from_pydict(dict(zip(schema.names, cols)), schema=schema)
 
     @staticmethod
-    def get_schema_bytes():
+    def get_schema_bytes_and_description():
         schema = pyarrow.schema({"col%s" % i: pyarrow.uint32() for i in range(4)})
+        description = [
+            ("col%s" % i, "int", None, None, None, None, None) for i in range(4)
+        ]
         sink = pyarrow.BufferOutputStream()
         writer = pyarrow.ipc.RecordBatchStreamWriter(sink, schema)
         writer.close()
-        return sink.getvalue().to_pybytes()
+        return sink.getvalue().to_pybytes(), description
 
     @patch(
         "databricks.sql.utils.ThriftCloudFetchQueue._create_next_table",
         return_value=[None, None],
     )
     def test_initializer_adds_links(self, mock_create_next_table):
-        schema_bytes = MagicMock()
+        schema_bytes, description = self.get_schema_bytes_and_description()
         result_links = self.create_result_links(10)
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
@@ -66,14 +69,18 @@ def test_initializer_adds_links(self, mock_create_next_table):
             session_id_hex=Mock(),
             statement_id=Mock(),
             chunk_id=0,
+            description=description,
         )
 
-        assert len(queue.download_manager._pending_links) == 10
-        assert len(queue.download_manager._download_tasks) == 0
+        assert (
+            len(queue.download_manager._pending_links)
+            + len(queue.download_manager._download_tasks)
+            == 10
+        )
         mock_create_next_table.assert_called()
 
     def test_initializer_no_links_to_add(self):
-        schema_bytes = MagicMock()
+        schema_bytes, description = self.get_schema_bytes_and_description()
         result_links = []
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
"databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", - return_value=None, - ) - def test_create_next_table_no_download(self, mock_get_next_downloaded_file): - queue = utils.ThriftCloudFetchQueue( - MagicMock(), - result_links=[], - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) - - assert queue._create_next_table() is None - mock_get_next_downloaded_file.assert_called_with(0) @patch("databricks.sql.utils.create_arrow_table_from_arrow_file") @patch( @@ -116,10 +105,23 @@ def test_initializer_create_next_table_success( self, mock_get_next_downloaded_file, mock_create_arrow_table ): mock_create_arrow_table.return_value = self.make_arrow_table() - schema_bytes, description = MagicMock(), MagicMock() + schema_bytes, description = self.get_schema_bytes_and_description() queue = utils.ThriftCloudFetchQueue( schema_bytes, - result_links=[], + result_links=[ + TSparkArrowResultLink( + fileLink="fileLink", + startRowOffset=0, + rowCount=4, + bytesNum=10, + ), + TSparkArrowResultLink( + fileLink="fileLink", + startRowOffset=4, + rowCount=4, + bytesNum=10, + ), + ], description=description, max_download_threads=10, ssl_options=SSLOptions(), @@ -144,10 +146,17 @@ def test_initializer_create_next_table_success( @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") def test_next_n_rows_0_rows(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() - schema_bytes, description = MagicMock(), MagicMock() + schema_bytes, description = self.get_schema_bytes_and_description() queue = utils.ThriftCloudFetchQueue( schema_bytes, - result_links=[], + result_links=[ + TSparkArrowResultLink( + fileLink="fileLink", + startRowOffset=0, + rowCount=4, + bytesNum=10, + ) + ], description=description, max_download_threads=10, ssl_options=SSLOptions(), @@ -166,10 +175,17 @@ def test_next_n_rows_0_rows(self, mock_create_next_table): @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") def test_next_n_rows_partial_table(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() - schema_bytes, description = MagicMock(), MagicMock() + schema_bytes, description = self.get_schema_bytes_and_description() queue = utils.ThriftCloudFetchQueue( schema_bytes, - result_links=[], + result_links=[ + TSparkArrowResultLink( + fileLink="fileLink", + startRowOffset=0, + rowCount=4, + bytesNum=10, + ) + ], description=description, max_download_threads=10, ssl_options=SSLOptions(), @@ -189,10 +205,17 @@ def test_next_n_rows_partial_table(self, mock_create_next_table): @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") def test_next_n_rows_more_than_one_table(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() - schema_bytes, description = MagicMock(), MagicMock() + schema_bytes, description = self.get_schema_bytes_and_description() queue = utils.ThriftCloudFetchQueue( schema_bytes, - result_links=[], + result_links=[ + TSparkArrowResultLink( + fileLink="fileLink", + startRowOffset=0, + rowCount=4, + bytesNum=10, + ) + ], description=description, max_download_threads=10, ssl_options=SSLOptions(), @@ -216,11 +239,21 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table") def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): - 
@@ -216,11 +239,21 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table):
 
     @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table")
     def test_next_n_rows_only_one_table_returned(self, mock_create_next_table):
-        mock_create_next_table.side_effect = [self.make_arrow_table(), None]
-        schema_bytes, description = MagicMock(), MagicMock()
+        mock_create_next_table.side_effect = [
+            self.make_arrow_table(),
+            self.make_arrow_table(0),
+        ]
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                )
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -241,11 +274,19 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table):
         return_value=None,
     )
     def test_next_n_rows_empty_table(self, mock_create_next_table):
-        schema_bytes = self.get_schema_bytes()
-        description = MagicMock()
+        mock_create_next_table.side_effect = [self.make_arrow_table(0)]
+
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                )
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -253,7 +294,8 @@ def test_next_n_rows_empty_table(self, mock_create_next_table):
             statement_id=Mock(),
             chunk_id=0,
         )
-        assert queue.table is None
+
+        assert queue.table == self.make_arrow_table(0)
 
         result = queue.next_n_rows(100)
         mock_create_next_table.assert_called()
@@ -261,11 +303,21 @@ def test_next_n_rows_empty_table(self, mock_create_next_table):
 
     @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table")
     def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table):
-        mock_create_next_table.side_effect = [self.make_arrow_table(), None, 0]
-        schema_bytes, description = MagicMock(), MagicMock()
+        mock_create_next_table.side_effect = [
+            self.make_arrow_table(),
+            self.make_arrow_table(0),
+        ]
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                )
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -283,11 +335,21 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table)
 
     @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table")
     def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_table):
-        mock_create_next_table.side_effect = [self.make_arrow_table(), None]
-        schema_bytes, description = MagicMock(), MagicMock()
+        mock_create_next_table.side_effect = [
+            self.make_arrow_table(),
+            self.make_arrow_table(0),
+        ]
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                )
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -305,11 +367,21 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl
 
     @patch("databricks.sql.utils.ThriftCloudFetchQueue._create_next_table")
     def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table):
-        mock_create_next_table.side_effect = [self.make_arrow_table(), None]
-        schema_bytes, description = MagicMock(), MagicMock()
+        mock_create_next_table.side_effect = [
+            self.make_arrow_table(),
+            self.make_arrow_table(0),
+        ]
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                )
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -332,12 +404,19 @@ def test_remaining_rows_multiple_tables_fully_returned(
         mock_create_next_table.side_effect = [
             self.make_arrow_table(),
             self.make_arrow_table(),
-            None,
+            self.make_arrow_table(0),
         ]
-        schema_bytes, description = MagicMock(), MagicMock()
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                )
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -364,11 +443,19 @@ def test_remaining_rows_multiple_tables_fully_returned(
         return_value=None,
     )
     def test_remaining_rows_empty_table(self, mock_create_next_table):
-        schema_bytes = self.get_schema_bytes()
-        description = MagicMock()
+        mock_create_next_table.side_effect = [self.make_arrow_table(0)]
+
+        schema_bytes, description = self.get_schema_bytes_and_description()
         queue = utils.ThriftCloudFetchQueue(
             schema_bytes,
-            result_links=[],
+            result_links=[
+                TSparkArrowResultLink(
+                    fileLink="fileLink",
+                    startRowOffset=0,
+                    rowCount=4,
+                    bytesNum=10,
+                )
+            ],
             description=description,
             max_download_threads=10,
             ssl_options=SSLOptions(),
@@ -376,7 +463,7 @@ def test_remaining_rows_empty_table(self, mock_create_next_table):
             statement_id=Mock(),
             chunk_id=0,
         )
-        assert queue.table is None
+        assert queue.table == self.make_arrow_table(0)
 
         result = queue.remaining_rows()
         assert result == pyarrow.ipc.open_stream(bytearray(schema_bytes)).read_all()
diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py
index a7cd92a51..5f1386a57 100644
--- a/tests/unit/test_downloader.py
+++ b/tests/unit/test_downloader.py
@@ -73,11 +73,15 @@ def test_run_get_response_not_ok(self, mock_time):
         settings.use_proxy = False
         result_link = Mock(expiryTime=1001)
 
-        with patch.object(
-            http_client,
-            "execute",
-            return_value=create_response(status_code=404, _content=b"1234"),
-        ):
+        # Create a mock response with 404 status
+        mock_response = create_response(status_code=404, _content=b"Not Found")
+        mock_response.raise_for_status = Mock(
+            side_effect=requests.exceptions.HTTPError("404")
+        )
+
+        with patch.object(http_client, "execute") as mock_execute:
+            mock_execute.return_value.__enter__.return_value = mock_response
+
             d = downloader.ResultSetDownloadHandler(
                 settings,
                 result_link,
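The rewritten downloader test mocks `http_client.execute` as a context manager, since the handler consumes the response via a `with` block. The `__enter__` wiring is plain `unittest.mock` behaviour; a self-contained sketch with hypothetical names:

```python
from unittest.mock import MagicMock, patch


class FakeClient:
    def execute(self, url):
        raise NotImplementedError  # replaced by the mock in the test


def fetch_status(client, url):
    # Code under test: uses the client's response as a context manager.
    with client.execute(url) as response:
        return response.status_code


def test_fetch_status_with_context_manager_mock():
    client = FakeClient()
    mock_response = MagicMock(status_code=404)
    with patch.object(client, "execute") as mock_execute:
        # MagicMock implements the context-manager protocol, so wiring
        # __enter__ makes `with client.execute(...)` yield mock_response.
        mock_execute.return_value.__enter__.return_value = mock_response
        assert fetch_status(client, "https://example.com") == 404
        mock_execute.assert_called_once_with("https://example.com")
```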
diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py
index cbeae098b..f0dcf5297 100644
--- a/tests/unit/test_sea_queue.py
+++ b/tests/unit/test_sea_queue.py
@@ -27,6 +27,11 @@
 import threading
 import time
 
+try:
+    import pyarrow as pa
+except ImportError:
+    pa = None
+
 
 class TestJsonQueue:
     """Test suite for the JsonQueue class."""
@@ -199,6 +204,7 @@ def test_build_queue_json_array(self, json_manifest, sample_data):
         assert isinstance(queue, JsonQueue)
         assert queue.data_array == sample_data
 
+    @pytest.mark.skipif(pa is None, reason="pyarrow not installed")
     def test_build_queue_arrow_stream(
         self, arrow_manifest, ssl_options, mock_sea_client, description
     ):
@@ -328,6 +334,7 @@ def test_convert_to_thrift_link_no_headers(self, sample_external_link_no_headers
 
     @patch("databricks.sql.backend.sea.queue.ResultFileDownloadManager")
     @patch("databricks.sql.backend.sea.queue.logger")
+    @pytest.mark.skipif(pa is None, reason="pyarrow not installed")
     def test_init_with_valid_initial_link(
         self,
         mock_logger,
@@ -357,6 +364,7 @@ def test_init_with_valid_initial_link(
 
     @patch("databricks.sql.backend.sea.queue.ResultFileDownloadManager")
     @patch("databricks.sql.backend.sea.queue.logger")
+    @pytest.mark.skipif(pa is None, reason="pyarrow not installed")
     def test_init_no_initial_links(
         self,
         mock_logger,
@@ -377,7 +385,7 @@ def test_init_no_initial_links(
             lz4_compressed=False,
             description=description,
         )
-        assert queue.table is None
+        assert queue.table == pa.Table.from_pydict({})
 
     @patch("databricks.sql.backend.sea.queue.logger")
     def test_create_next_table_success(self, mock_logger):
@@ -481,6 +489,7 @@ def test_hybrid_disposition_with_attachment(
 
     @patch("databricks.sql.backend.sea.queue.ResultFileDownloadManager")
     @patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None)
+    @pytest.mark.skipif(pa is None, reason="pyarrow not installed")
     def test_hybrid_disposition_with_external_links(
         self,
         mock_create_table,
- """ - - error_message = "Could not connect to host" - mock_session.side_effect = Exception(error_message) - - try: - from databricks import sql - - sql.connect(server_hostname="test-host", http_path="/test-path") - except Exception as e: - assert str(e) == error_message - - mock_export_failure_log.assert_called_once() - call_arguments = mock_export_failure_log.call_args - assert call_arguments[0][0] == "Exception" - assert call_arguments[0][1] == error_message diff --git a/tests/unit/test_telemetry_retry.py b/tests/unit/test_telemetry_retry.py index 11055b558..94137c5b1 100644 --- a/tests/unit/test_telemetry_retry.py +++ b/tests/unit/test_telemetry_retry.py @@ -6,7 +6,8 @@ from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory from databricks.sql.auth.retry import DatabricksRetryPolicy -PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn' +PATCH_TARGET = "urllib3.connectionpool.HTTPSConnectionPool._get_conn" + def create_mock_conn(responses): """Creates a mock connection object whose getresponse() method yields a series of responses.""" @@ -16,15 +17,18 @@ def create_mock_conn(responses): mock_http_response = MagicMock() mock_http_response.status = resp.get("status") mock_http_response.headers = resp.get("headers", {}) - body = resp.get("body", b'{}') + body = resp.get("body", b"{}") mock_http_response.fp = io.BytesIO(body) + def release(): mock_http_response.fp.close() + mock_http_response.release_conn = release mock_http_responses.append(mock_http_response) mock_conn.getresponse.side_effect = mock_http_responses return mock_conn + class TestTelemetryClientRetries: @pytest.fixture(autouse=True) def setup_and_teardown(self): @@ -49,28 +53,28 @@ def get_client(self, session_id, num_retries=3): host_url="test.databricks.com", ) client = TelemetryClientFactory.get_telemetry_client(session_id) - + retry_policy = DatabricksRetryPolicy( delay_min=0.01, delay_max=0.02, stop_after_attempts_duration=2.0, - stop_after_attempts_count=num_retries, + stop_after_attempts_count=num_retries, delay_default=0.1, force_dangerous_codes=[], - urllib3_kwargs={'total': num_retries} + urllib3_kwargs={"total": num_retries}, ) adapter = client._http_client.session.adapters.get("https://") adapter.max_retries = retry_policy return client @pytest.mark.parametrize( - "status_code, description", - [ - (401, "Unauthorized"), - (403, "Forbidden"), - (501, "Not Implemented"), - (200, "Success"), - ], + "status_code, description", + [ + (401, "Unauthorized"), + (403, "Forbidden"), + (501, "Not Implemented"), + (200, "Success"), + ], ) def test_non_retryable_status_codes_are_not_retried(self, status_code, description): """ @@ -80,7 +84,9 @@ def test_non_retryable_status_codes_are_not_retried(self, status_code, descripti client = self.get_client(f"session-{status_code}") mock_responses = [{"status": status_code}] - with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: + with patch( + PATCH_TARGET, return_value=create_mock_conn(mock_responses) + ) as mock_get_conn: client.export_failure_log("TestError", "Test message") TelemetryClientFactory.close(client._session_id_hex) @@ -92,16 +98,26 @@ def test_exceeds_retry_count_limit(self): Verifies that the client respects the Retry-After header and retries on 429, 502, 503. 
""" num_retries = 3 - expected_total_calls = num_retries + 1 + expected_total_calls = num_retries + 1 retry_after = 1 client = self.get_client("session-exceed-limit", num_retries=num_retries) - mock_responses = [{"status": 503, "headers": {"Retry-After": str(retry_after)}}, {"status": 429}, {"status": 502}, {"status": 503}] - - with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: + mock_responses = [ + {"status": 503, "headers": {"Retry-After": str(retry_after)}}, + {"status": 429}, + {"status": 502}, + {"status": 503}, + ] + + with patch( + PATCH_TARGET, return_value=create_mock_conn(mock_responses) + ) as mock_get_conn: start_time = time.time() client.export_failure_log("TestError", "Test message") TelemetryClientFactory.close(client._session_id_hex) end_time = time.time() - - assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls - assert end_time - start_time > retry_after \ No newline at end of file + + assert ( + mock_get_conn.return_value.getresponse.call_count + == expected_total_calls + ) + assert end_time - start_time > retry_after