Download Manager: Stop shutdown in case of empty download tasks Queue #641

Open · wants to merge 12 commits into base: sea-migration · changes shown from all commits
6 changes: 3 additions & 3 deletions src/databricks/sql/backend/sea/queue.py
@@ -363,14 +363,14 @@ def __init__(
         # Initialize table and position
         self.table = self._create_next_table()
 
-    def _create_next_table(self) -> Union["pyarrow.Table", None]:
+    def _create_next_table(self) -> "pyarrow.Table":
         """Create next table by retrieving the logical next downloaded file."""
         if self.link_fetcher is None:
-            return None
+            return self._create_empty_table()
 
         chunk_link = self.link_fetcher.get_chunk_link(self._current_chunk_index)
         if chunk_link is None:
-            return None
+            return self._create_empty_table()
 
         row_offset = chunk_link.row_offset
         # NOTE: link has already been submitted to download manager at this point
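These hunks replace the None sentinel with an empty pyarrow.Table, so callers never branch on Optional. A minimal sketch of that sentinel, assuming the queue knows its result schema (the free function is illustrative; in the connector this role is played by the queue's _create_empty_table helper):

import pyarrow

def make_empty_table(schema: pyarrow.Schema) -> pyarrow.Table:
    # Zero rows but the correct schema: slicing and concatenation still
    # work, and num_rows == 0 doubles as the end-of-stream signal.
    return schema.empty_table()

schema = pyarrow.schema([("id", pyarrow.int64()), ("name", pyarrow.string())])
assert make_empty_table(schema).num_rows == 0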
28 changes: 18 additions & 10 deletions src/databricks/sql/cloudfetch/download_manager.py
@@ -1,13 +1,15 @@
 import logging
 
 from concurrent.futures import ThreadPoolExecutor, Future
+import threading
 from typing import List, Union, Tuple, Optional
 
 from databricks.sql.cloudfetch.downloader import (
     ResultSetDownloadHandler,
     DownloadableResultSettings,
     DownloadedFile,
 )
+from databricks.sql.exc import Error
 from databricks.sql.types import SSLOptions
 from databricks.sql.telemetry.models.event import StatementType
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
@@ -39,26 +41,24 @@ def __init__(
             self._pending_links.append((i, link))
         self.chunk_id += len(links)
 
-        self._download_tasks: List[Future[DownloadedFile]] = []
         self._max_download_threads: int = max_download_threads
+
+        self._download_condition = threading.Condition()
+        self._download_tasks: List[Future[DownloadedFile]] = []
         self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
 
         self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
         self._ssl_options = ssl_options
         self.session_id_hex = session_id_hex
         self.statement_id = statement_id
 
-    def get_next_downloaded_file(
-        self, next_row_offset: int
-    ) -> Union[DownloadedFile, None]:
+    def get_next_downloaded_file(self, next_row_offset: int) -> DownloadedFile:
         """
         Get next file that starts at given offset.
 
         This function gets the next downloaded file in which its rows start at the specified next_row_offset
         in relation to the full result. File downloads are scheduled if not already, and once the correct
         download handler is located, the function waits for the download status and returns the resulting file.
-        If there are no more downloads, a download was not successful, or the correct file could not be located,
-        this function shuts down the thread pool and returns None.
 
         Args:
             next_row_offset (int): The offset of the starting row of the next file we want data from.
@@ -67,10 +67,11 @@ def get_next_downloaded_file(
         # Make sure the download queue is always full
         self._schedule_downloads()
 
-        # No more files to download from this batch of links
-        if len(self._download_tasks) == 0:
-            self._shutdown_manager()
-            return None
+        while len(self._download_tasks) == 0:
+            if self._thread_pool._shutdown:
+                raise Error("download manager shut down before file was ready")
+            with self._download_condition:
+                self._download_condition.wait()
 
         task = self._download_tasks.pop(0)
         # Future's `result()` method will wait for the call to complete, and return
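The wait loop above is the standard condition-variable consumer, with _schedule_downloads playing producer by notifying after it submits work. A self-contained sketch of the handshake under hypothetical names (TaskQueue is not the connector's API):

import threading
from typing import List

class TaskQueue:
    def __init__(self) -> None:
        self._tasks: List[str] = []
        self._shutdown = False
        self._cond = threading.Condition()

    def put(self, task: str) -> None:  # producer, cf. _schedule_downloads
        with self._cond:
            self._tasks.append(task)
            self._cond.notify_all()  # wake any consumer blocked in get()

    def shutdown(self) -> None:  # cf. _shutdown_manager
        with self._cond:
            self._shutdown = True
            self._cond.notify_all()  # wake waiters so they can fail fast

    def get(self) -> str:  # consumer, cf. get_next_downloaded_file
        with self._cond:
            # wait() releases the lock while blocked, so producers can run;
            # the predicate is re-checked after every wakeup.
            while not self._tasks:
                if self._shutdown:
                    raise RuntimeError("shut down before a task was ready")
                self._cond.wait()
            return self._tasks.pop(0)

One difference worth noting: the sketch holds the lock across the emptiness check and wait(), so a put() cannot slip in between unobserved, whereas the diff evaluates len(self._download_tasks) == 0 outside the with block, leaving a small window in which a notification could be missed until the next notify_all(). The diff also reads the executor's private _shutdown attribute; a public event flag, as in the sketch, would avoid depending on CPython internals.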
@@ -113,6 +114,9 @@ def _schedule_downloads(self):
             task = self._thread_pool.submit(handler.run)
             self._download_tasks.append(task)
 
+        with self._download_condition:
+            self._download_condition.notify_all()
+
     def add_link(self, link: TSparkArrowResultLink):
         """
         Add more links to the download manager.
@@ -132,8 +136,12 @@ def add_link(self, link: TSparkArrowResultLink):
             self._pending_links.append((self.chunk_id, link))
             self.chunk_id += 1
 
+        self._schedule_downloads()
+
     def _shutdown_manager(self):
         # Clear download handlers and shutdown the thread pool
         self._pending_links = []
         self._download_tasks = []
         self._thread_pool.shutdown(wait=False)
+        with self._download_condition:
+            self._download_condition.notify_all()
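The notify_all() added to _shutdown_manager is what turns a would-be hang into the Error raised in the wait loop above. A tiny runnable illustration of that wake-up (illustrative names only, not the connector's code):

import threading

cond = threading.Condition()
state = {"tasks": [], "shutdown": False}

def waiter() -> None:
    with cond:
        while not state["tasks"]:
            if state["shutdown"]:
                print("woken by shutdown; bailing out instead of hanging")
                return
            cond.wait()

t = threading.Thread(target=waiter)
t.start()

with cond:
    state["shutdown"] = True
    cond.notify_all()  # without this, waiter() would block in wait() forever
t.join()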
56 changes: 25 additions & 31 deletions src/databricks/sql/utils.py
@@ -249,7 +249,7 @@ def __init__(
         self.chunk_id = chunk_id
 
         # Table state
-        self.table = None
+        self.table = self._create_empty_table()
         self.table_row_index = 0
 
         # Initialize download manager
@@ -272,23 +272,20 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
         Returns:
             pyarrow.Table
         """
-        if not self.table:
-            logger.debug("CloudFetchQueue: no more rows available")
-            # Return empty pyarrow table to cause retry of fetch
-            return self._create_empty_table()
+
         logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows))
         results = self.table.slice(0, 0)
-        while num_rows > 0 and self.table:
+        while num_rows > 0 and self.table.num_rows > 0:
+            # Replace current table with the next table if we are at the end of the current table
+            if self.table_row_index == self.table.num_rows:
+                self.table = self._create_next_table()
+                self.table_row_index = 0
+
             # Get remaining of num_rows or the rest of the current table, whichever is smaller
             length = min(num_rows, self.table.num_rows - self.table_row_index)
             table_slice = self.table.slice(self.table_row_index, length)
             results = pyarrow.concat_tables([results, table_slice])
             self.table_row_index += table_slice.num_rows
-
-            # Replace current table with the next table if we are at the end of the current table
-            if self.table_row_index == self.table.num_rows:
-                self.table = self._create_next_table()
-                self.table_row_index = 0
             num_rows -= table_slice.num_rows
 
         logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
@@ -302,11 +299,8 @@ def remaining_rows(self) -> "pyarrow.Table":
             pyarrow.Table
         """
 
-        if not self.table:
-            # Return empty pyarrow table to cause retry of fetch
-            return self._create_empty_table()
         results = self.table.slice(0, 0)
-        while self.table:
+        while self.table.num_rows > 0:
             table_slice = self.table.slice(
                 self.table_row_index, self.table.num_rows - self.table_row_index
             )
@@ -316,17 +310,11 @@ def remaining_rows(self) -> "pyarrow.Table":
             self.table_row_index = 0
         return results
 
-    def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]:
+    def _create_table_at_offset(self, offset: int) -> "pyarrow.Table":
         """Create next table at the given row offset"""
 
-        # Create next table by retrieving the logical next downloaded file, or return None to signal end of queue
         downloaded_file = self.download_manager.get_next_downloaded_file(offset)
-        if not downloaded_file:
-            logger.debug(
-                "CloudFetchQueue: Cannot find downloaded file for row {}".format(offset)
-            )
-            # None signals no more Arrow tables can be built from the remaining handlers if any remain
-            return None
+
         arrow_table = create_arrow_table_from_arrow_file(
             downloaded_file.file_bytes, self.description
         )
@@ -342,7 +330,7 @@ def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]:
         return arrow_table
 
     @abstractmethod
-    def _create_next_table(self) -> Union["pyarrow.Table", None]:
+    def _create_next_table(self) -> "pyarrow.Table":
         """Create next table by retrieving the logical next downloaded file."""
         pass

@@ -361,7 +349,7 @@ class ThriftCloudFetchQueue(CloudFetchQueue):
 
     def __init__(
         self,
-        schema_bytes,
+        schema_bytes: Optional[bytes],
         max_download_threads: int,
         ssl_options: SSLOptions,
         session_id_hex: Optional[str],
@@ -406,6 +394,9 @@ def __init__(
                 start_row_offset
             )
         )
+
+        self.num_links_downloaded = 0
+
         if self.result_links:
             for result_link in self.result_links:
                 logger.debug(
@@ -418,20 +409,23 @@ def __init__(
         # Initialize table and position
         self.table = self._create_next_table()
 
-    def _create_next_table(self) -> Union["pyarrow.Table", None]:
+    def _create_next_table(self) -> "pyarrow.Table":
+        if self.num_links_downloaded >= len(self.result_links):
+            return self._create_empty_table()
+
         logger.debug(
             "ThriftCloudFetchQueue: Trying to get downloaded file for row {}".format(
                 self.start_row_index
             )
         )
         arrow_table = self._create_table_at_offset(self.start_row_index)
-        if arrow_table:
-            self.start_row_index += arrow_table.num_rows
-            logger.debug(
-                "ThriftCloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format(
-                    arrow_table.num_rows, self.start_row_index
-                )
-            )
+        self.num_links_downloaded += 1
+        self.start_row_index += arrow_table.num_rows
+        logger.debug(
+            "ThriftCloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format(
+                arrow_table.num_rows, self.start_row_index
+            )
+        )
         return arrow_table
 
 
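Together with the base-class changes, the Thrift queue now ends the stream by counting consumed links instead of returning None. Schematically (a hypothetical miniature, with in-memory tables standing in for downloaded link results):

import pyarrow

class MiniThriftQueue:
    # Hypothetical miniature of the counter-based end-of-stream check.
    def __init__(self, tables: list) -> None:
        self._tables = tables  # stands in for self.result_links
        self._consumed = 0     # stands in for self.num_links_downloaded

    def next_table(self) -> pyarrow.Table:
        if self._consumed >= len(self._tables):
            # All known links used up: hand back the zero-row sentinel.
            return self._tables[0].slice(0, 0)
        table = self._tables[self._consumed]
        self._consumed += 1
        return table

q = MiniThriftQueue([pyarrow.table({"x": [1, 2]}), pyarrow.table({"x": [3]})])
print([q.next_table().num_rows for _ in range(3)])  # [2, 1, 0]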