diff --git a/src/together/lib/constants.py b/src/together/lib/constants.py index ed64b08e..bae83c64 100644 --- a/src/together/lib/constants.py +++ b/src/together/lib/constants.py @@ -14,6 +14,9 @@ # Download defaults DOWNLOAD_BLOCK_SIZE = 10 * 1024 * 1024 # 10 MB DISABLE_TQDM = False +MAX_DOWNLOAD_RETRIES = 5 # Maximum retries for download failures +DOWNLOAD_INITIAL_RETRY_DELAY = 1.0 # Initial retry delay in seconds +DOWNLOAD_MAX_RETRY_DELAY = 30.0 # Maximum retry delay in seconds # Upload defaults MAX_CONCURRENT_PARTS = 4 # Maximum concurrent parts for multipart upload diff --git a/src/together/lib/resources/files.py b/src/together/lib/resources/files.py index b82e84d8..20d5be12 100644 --- a/src/together/lib/resources/files.py +++ b/src/together/lib/resources/files.py @@ -3,6 +3,7 @@ import os import math import stat +import time import uuid import shutil import asyncio @@ -29,12 +30,15 @@ MAX_MULTIPART_PARTS, TARGET_PART_SIZE_MB, MAX_CONCURRENT_PARTS, + MAX_DOWNLOAD_RETRIES, MULTIPART_THRESHOLD_GB, + DOWNLOAD_MAX_RETRY_DELAY, MULTIPART_UPLOAD_TIMEOUT, + DOWNLOAD_INITIAL_RETRY_DELAY, ) from ..._resource import SyncAPIResource, AsyncAPIResource from ..types.error import DownloadError, FileTypeError -from ..._exceptions import APIStatusError, AuthenticationError +from ..._exceptions import APIStatusError, APIConnectionError, AuthenticationError log: logging.Logger = logging.getLogger(__name__) @@ -198,6 +202,11 @@ def download( assert file_size != 0, "Unable to retrieve remote file." 
+ # Download with retry logic + bytes_downloaded = 0 + retry_count = 0 + retry_delay = DOWNLOAD_INITIAL_RETRY_DELAY + with tqdm( total=file_size, unit="B", @@ -205,14 +214,64 @@ desc=f"Downloading file {file_path.name}", disable=bool(DISABLE_TQDM), ) as pbar: - for chunk in response.iter_bytes(DOWNLOAD_BLOCK_SIZE): - pbar.update(len(chunk)) - temp_file.write(chunk) # type: ignore + while bytes_downloaded < file_size: + try: + # If this is a retry, close the previous response and create a new one with Range header + if retry_count > 0: + response.close() + + log.info(f"Resuming download from byte {bytes_downloaded}") + response = self._client.get( + path=url, + cast_to=httpx.Response, + stream=True, + options=RequestOptions( + headers={"Range": f"bytes={bytes_downloaded}-"}, + ), + ) + + # Download chunks + for chunk in response.iter_bytes(DOWNLOAD_BLOCK_SIZE): + temp_file.write(chunk) # type: ignore + bytes_downloaded += len(chunk) + pbar.update(len(chunk)) + + # Successfully completed download + break + + except (httpx.RequestError, httpx.StreamError, APIConnectionError) as e: + if retry_count >= MAX_DOWNLOAD_RETRIES: + log.error(f"Download failed after {retry_count} retries") + raise DownloadError( + f"Download failed after {retry_count} retries. Last error: {str(e)}" + ) from e + + retry_count += 1 + log.warning( + f"Download interrupted at {bytes_downloaded}/{file_size} bytes. " + f"Retry {retry_count}/{MAX_DOWNLOAD_RETRIES} in {retry_delay}s..." 
+ ) + time.sleep(retry_delay) + + # Exponential backoff with max delay cap + retry_delay = min(retry_delay * 2, DOWNLOAD_MAX_RETRY_DELAY) + + except APIStatusError as e: + # For API errors, don't retry + log.error(f"API error during download: {e}") + raise APIStatusError( + "Error downloading file", + response=e.response, + body=e.response, + ) from e + + # Close the response + response.close() # Raise exception if remote file size does not match downloaded file size if os.stat(temp_file.name).st_size != file_size: - DownloadError( - f"Downloaded file size `{pbar.n}` bytes does not match remote file size `{file_size}` bytes." + raise DownloadError( + f"Downloaded file size `{bytes_downloaded}` bytes does not match remote file size `{file_size}` bytes." ) # Moves temp file to output file path