diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 521804da..19cfee53 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "2.0.0-alpha.10" + ".": "2.0.0-alpha.11" } \ No newline at end of file diff --git a/.stats.yml b/.stats.yml index 68f013aa..97040984 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 44 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/togetherai%2Ftogetherai-291c169d09f5ccc52f1e20f6f239db136003f4735ebd82f14f10cfdf96bb88fd.yml -openapi_spec_hash: 241fba23e79ab8bcfb06c7781c01aa27 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/togetherai%2Ftogetherai-817bdc0e9a5082575f07386056968f56af20cbc40cbbc716ab4b8c4ec9220b53.yml +openapi_spec_hash: 30b3f6d251dfd02bca8ffa3f755e7574 config_hash: 9749f2f8998aa6b15452b2187ff675b9 diff --git a/CHANGELOG.md b/CHANGELOG.md index 1dab4250..c5ed4e2b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,20 @@ # Changelog +## 2.0.0-alpha.11 (2025-12-16) + +Full Changelog: [v2.0.0-alpha.10...v2.0.0-alpha.11](https://github.com/togethercomputer/together-py/compare/v2.0.0-alpha.10...v2.0.0-alpha.11) + +### Features + +* **api:** api update ([17ad3ec](https://github.com/togethercomputer/together-py/commit/17ad3ec91a06a7e886252d4b688c3a9e217a3799)) +* **api:** api update ([ebc3414](https://github.com/togethercomputer/together-py/commit/ebc3414e28db0309fef5aeed456e242048b5d13c)) +* **files:** add support for string alternative to file upload type ([db59ed6](https://github.com/togethercomputer/together-py/commit/db59ed6235f2e18db100a72084c2fefc22354d15)) + + +### Chores + +* **internal:** add missing files argument to base client ([6977285](https://github.com/togethercomputer/together-py/commit/69772856908b8378c74eed382735523e91011d90)) + ## 2.0.0-alpha.10 (2025-12-15) Full Changelog: 
[v2.0.0-alpha.9...v2.0.0-alpha.10](https://github.com/togethercomputer/together-py/compare/v2.0.0-alpha.9...v2.0.0-alpha.10) diff --git a/pyproject.toml b/pyproject.toml index 218b08b5..93893bcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "together" -version = "2.0.0-alpha.10" +version = "2.0.0-alpha.11" description = "The official Python library for the together API" dynamic = ["readme"] license = "Apache-2.0" diff --git a/src/together/_base_client.py b/src/together/_base_client.py index 0ffa0d3c..f63640d4 100644 --- a/src/together/_base_client.py +++ b/src/together/_base_client.py @@ -1247,9 +1247,12 @@ def patch( *, cast_to: Type[ResponseT], body: Body | None = None, + files: RequestFiles | None = None, options: RequestOptions = {}, ) -> ResponseT: - opts = FinalRequestOptions.construct(method="patch", url=path, json_data=body, **options) + opts = FinalRequestOptions.construct( + method="patch", url=path, json_data=body, files=to_httpx_files(files), **options + ) return self.request(cast_to, opts) def put( @@ -1767,9 +1770,12 @@ async def patch( *, cast_to: Type[ResponseT], body: Body | None = None, + files: RequestFiles | None = None, options: RequestOptions = {}, ) -> ResponseT: - opts = FinalRequestOptions.construct(method="patch", url=path, json_data=body, **options) + opts = FinalRequestOptions.construct( + method="patch", url=path, json_data=body, files=to_httpx_files(files), **options + ) return await self.request(cast_to, opts) async def put( diff --git a/src/together/_version.py b/src/together/_version.py index ec189fa4..14e05b77 100644 --- a/src/together/_version.py +++ b/src/together/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
__title__ = "together" -__version__ = "2.0.0-alpha.10" # x-release-please-version +__version__ = "2.0.0-alpha.11" # x-release-please-version diff --git a/src/together/lib/cli/api/fine_tuning.py b/src/together/lib/cli/api/fine_tuning.py index 4b3df577..26ab89a5 100644 --- a/src/together/lib/cli/api/fine_tuning.py +++ b/src/together/lib/cli/api/fine_tuning.py @@ -10,6 +10,7 @@ import click from rich import print as rprint from tabulate import tabulate +from rich.json import JSON from click.core import ParameterSource # type: ignore[attr-defined] from together import Together @@ -17,7 +18,7 @@ from together._types import NOT_GIVEN, NotGiven from together.lib.utils import log_warn from together.lib.utils.tools import format_timestamp, finetune_price_to_dollars -from together.lib.cli.api.utils import INT_WITH_MAX, BOOL_WITH_AUTO +from together.lib.cli.api.utils import INT_WITH_MAX, BOOL_WITH_AUTO, generate_progress_bar from together.lib.resources.files import DownloadManager from together.lib.utils.serializer import datetime_serializer from together.types.finetune_response import TrainingTypeFullTrainingType, TrainingTypeLoRaTrainingType @@ -361,7 +362,7 @@ def create( rpo_alpha=rpo_alpha or 0, simpo_gamma=simpo_gamma or 0, ) - + finetune_price_estimation_result = client.fine_tuning.estimate_price( training_file=training_file, validation_file=validation_file, @@ -425,6 +426,9 @@ def list(ctx: click.Context) -> None: "Price": f"""${ finetune_price_to_dollars(float(str(i.total_price))) }""", # convert to string for mypy typing + "Progress": generate_progress_bar( + i, datetime.now().astimezone(), use_rich=False + ), } ) table = tabulate(display_list, headers="keys", tablefmt="grid", showindex=True) @@ -444,7 +448,12 @@ def retrieve(ctx: click.Context, fine_tune_id: str) -> None: # remove events from response for cleaner output response.events = None - click.echo(json.dumps(response.model_dump(exclude_none=True), indent=4)) + 
rprint(JSON.from_data(response.model_dump(mode="json", exclude_none=True))) + progress_text = generate_progress_bar( + response, datetime.now().astimezone(), use_rich=True + ) + prefix = f"Status: [bold]{response.status}[/bold]," + rprint(f"{prefix} {progress_text}") @fine_tuning.command() diff --git a/src/together/lib/cli/api/utils.py b/src/together/lib/cli/api/utils.py index 5c7c8241..eabb4d07 100644 --- a/src/together/lib/cli/api/utils.py +++ b/src/together/lib/cli/api/utils.py @@ -1,18 +1,25 @@ from __future__ import annotations -from typing import Literal +import re +import math +from typing import List, Union, Literal from gettext import gettext as _ -from typing_extensions import override +from datetime import datetime import click +from together.lib.types.fine_tuning import COMPLETED_STATUSES, FinetuneResponse +from together.types.finetune_response import FinetuneResponse as _FinetuneResponse +from together.types.fine_tuning_list_response import Data + +_PROGRESS_BAR_WIDTH = 40 + class AutoIntParamType(click.ParamType): name = "integer_or_max" _number_class = int - @override - def convert( + def convert( # pyright: ignore[reportImplicitOverride] self, value: str, param: click.Parameter | None, ctx: click.Context | None ) -> int | Literal["max"] | None: if value == "max": @@ -21,7 +28,9 @@ def convert( return int(value) except ValueError: self.fail( - _("{value!r} is not a valid {number_type}.").format(value=value, number_type=self.name), + _("{value!r} is not a valid {number_type}.").format( + value=value, number_type=self.name + ), param, ctx, ) @@ -30,8 +39,7 @@ def convert( class BooleanWithAutoParamType(click.ParamType): name = "boolean_or_auto" - @override - def convert( + def convert( # pyright: ignore[reportImplicitOverride] self, value: str, param: click.Parameter | None, ctx: click.Context | None ) -> bool | Literal["auto"] | None: if value == "auto": @@ -40,7 +48,9 @@ def convert( return bool(value) except ValueError: self.fail( - _("{value!r} is not a valid
{type}.").format(value=value, type=self.name), + _("{value!r} is not a valid {type}.").format( + value=value, type=self.name + ), param, ctx, ) @@ -48,3 +58,82 @@ def convert( INT_WITH_MAX = AutoIntParamType() BOOL_WITH_AUTO = BooleanWithAutoParamType() + + +def _human_readable_time(timedelta: float) -> str: + """Convert a timedelta to a compact human-readable string + Examples: + 00:00:10 -> 10s + 01:23:45 -> 1h 23min 45s + 1 Month 23 days 04:56:07 -> 1month 23d 4h 56min 7s + Args: + timedelta (float): The timedelta in seconds to convert. + Returns: + A string representing the timedelta in a human-readable format. + """ + units = [ + (30 * 24 * 60 * 60, "month"), # 30 days + (24 * 60 * 60, "d"), + (60 * 60, "h"), + (60, "min"), + (1, "s"), + ] + + total_seconds = int(timedelta) + parts: List[str] = [] + + for unit_seconds, unit_name in units: + if total_seconds >= unit_seconds: + value = total_seconds // unit_seconds + total_seconds %= unit_seconds + parts.append(f"{value}{unit_name}") + + return " ".join(parts) if parts else "0s" + + +def generate_progress_bar( + finetune_job: Union[Data, FinetuneResponse, _FinetuneResponse], current_time: datetime, use_rich: bool = False +) -> str: + """Generate a progress bar for a finetune job. + Args: + finetune_job: The finetune job to generate a progress bar for. + current_time: The current time. + use_rich: Whether to keep Rich markup tags in the result; when False they are stripped. + Returns: + A string representing the progress bar.
+ """ + progress = "Progress: [bold red]unavailable[/bold red]" + if finetune_job.status in COMPLETED_STATUSES: + progress = "Progress: [bold green]completed[/bold green]" + elif finetune_job.updated_at is not None: + update_at = finetune_job.updated_at.astimezone() + + # Draw a bar only when a usable estimate exists. Every other case falls + # through with the "unavailable" text so the markup-stripping step at the + # end of the function still runs when use_rich is False (the previous + # early returns skipped it and leaked raw Rich tags). + if ( + finetune_job.progress is not None + and current_time >= update_at + and finetune_job.progress.estimate_available + and finetune_job.progress.seconds_remaining > 0 + ): + elapsed_time = (current_time - update_at).total_seconds() + ratio_filled = min( + elapsed_time / finetune_job.progress.seconds_remaining, 1.0 + ) + percentage = ratio_filled * 100 + filled = math.ceil(ratio_filled * _PROGRESS_BAR_WIDTH) + bar = "█" * filled + "░" * (_PROGRESS_BAR_WIDTH - filled) + time_left = "N/A" + if finetune_job.progress.seconds_remaining > elapsed_time: + time_left = _human_readable_time( + finetune_job.progress.seconds_remaining - elapsed_time + ) + time_text = f"{time_left} left" + progress = f"Progress: {bar} [bold]{percentage:>3.0f}%[/bold] [yellow]{time_text}[/yellow]" + + if use_rich: + return progress + + return re.sub(r"\[/?[^\]]+\]", "", progress) diff --git a/src/together/lib/constants.py b/src/together/lib/constants.py index ed64b08e..bae83c64 100644 --- a/src/together/lib/constants.py +++ b/src/together/lib/constants.py @@ -14,6 +14,9 @@ # Download defaults DOWNLOAD_BLOCK_SIZE = 10 * 1024 * 1024 # 10 MB DISABLE_TQDM = False +MAX_DOWNLOAD_RETRIES = 5 # Maximum retries for download failures +DOWNLOAD_INITIAL_RETRY_DELAY = 1.0 # Initial retry delay in seconds +DOWNLOAD_MAX_RETRY_DELAY = 30.0 # Maximum retry delay in seconds # Upload defaults MAX_CONCURRENT_PARTS = 4 # Maximum concurrent parts for multipart upload diff --git a/src/together/lib/resources/files.py b/src/together/lib/resources/files.py index b82e84d8..20d5be12 100644 --- a/src/together/lib/resources/files.py +++ b/src/together/lib/resources/files.py @@ -3,6 +3,7 @@ import os
import math import stat +import time import uuid import shutil import asyncio @@ -29,12 +30,15 @@ MAX_MULTIPART_PARTS, TARGET_PART_SIZE_MB, MAX_CONCURRENT_PARTS, + MAX_DOWNLOAD_RETRIES, MULTIPART_THRESHOLD_GB, + DOWNLOAD_MAX_RETRY_DELAY, MULTIPART_UPLOAD_TIMEOUT, + DOWNLOAD_INITIAL_RETRY_DELAY, ) from ..._resource import SyncAPIResource, AsyncAPIResource from ..types.error import DownloadError, FileTypeError -from ..._exceptions import APIStatusError, AuthenticationError +from ..._exceptions import APIStatusError, APIConnectionError, AuthenticationError log: logging.Logger = logging.getLogger(__name__) @@ -198,6 +202,11 @@ def download( assert file_size != 0, "Unable to retrieve remote file." + # Download with retry logic + bytes_downloaded = 0 + retry_count = 0 + retry_delay = DOWNLOAD_INITIAL_RETRY_DELAY + with tqdm( total=file_size, unit="B", @@ -205,14 +214,64 @@ def download( desc=f"Downloading file {file_path.name}", disable=bool(DISABLE_TQDM), ) as pbar: - for chunk in response.iter_bytes(DOWNLOAD_BLOCK_SIZE): - pbar.update(len(chunk)) - temp_file.write(chunk) # type: ignore + while bytes_downloaded < file_size: + try: + # On any retry (even one that failed before the first byte), close the previous response and re-request from the current offset with a Range header + if retry_count > 0: + response.close() + + log.info(f"Resuming download from byte {bytes_downloaded}") + response = self._client.get( + path=url, + cast_to=httpx.Response, + stream=True, + options=RequestOptions( + headers={"Range": f"bytes={bytes_downloaded}-"}, + ), + ) + + # Download chunks + for chunk in response.iter_bytes(DOWNLOAD_BLOCK_SIZE): + temp_file.write(chunk) # type: ignore + bytes_downloaded += len(chunk) + pbar.update(len(chunk)) + + # Successfully completed download + break + + except (httpx.RequestError, httpx.StreamError, APIConnectionError) as e: + if retry_count >= MAX_DOWNLOAD_RETRIES: + log.error(f"Download failed after {retry_count} retries") + raise DownloadError( + f"Download failed after {retry_count} retries.
Last error: {str(e)}" + ) from e + + retry_count += 1 + log.warning( + f"Download interrupted at {bytes_downloaded}/{file_size} bytes. " + f"Retry {retry_count}/{MAX_DOWNLOAD_RETRIES} in {retry_delay}s..." + ) + time.sleep(retry_delay) + + # Exponential backoff with max delay cap + retry_delay = min(retry_delay * 2, DOWNLOAD_MAX_RETRY_DELAY) + + except APIStatusError as e: + # For API errors, don't retry + log.error(f"API error during download: {e}") + raise APIStatusError( + "Error downloading file", + response=e.response, + body=e.body, # pass the original error's body, not its response object + ) from e + + # Close the response + response.close() # Raise exception if remote file size does not match downloaded file size if os.stat(temp_file.name).st_size != file_size: - DownloadError( - f"Downloaded file size `{pbar.n}` bytes does not match remote file size `{file_size}` bytes." + raise DownloadError( + f"Downloaded file size `{bytes_downloaded}` bytes does not match remote file size `{file_size}` bytes." ) # Moves temp file to output file path diff --git a/src/together/lib/types/fine_tuning.py b/src/together/lib/types/fine_tuning.py index 87f96857..d3888857 100644 --- a/src/together/lib/types/fine_tuning.py +++ b/src/together/lib/types/fine_tuning.py @@ -25,6 +25,14 @@ class FinetuneJobStatus(str, Enum): STATUS_COMPLETED = "completed" +COMPLETED_STATUSES = [ + FinetuneJobStatus.STATUS_ERROR, + FinetuneJobStatus.STATUS_USER_ERROR, + FinetuneJobStatus.STATUS_COMPLETED, + FinetuneJobStatus.STATUS_CANCELLED, +] + + class FinetuneEventType(str, Enum): """ Fine-tune job event types @@ -260,6 +268,15 @@ class UnknownLRScheduler(BaseModel): ] +class FinetuneProgress(BaseModel): + """ + Fine-tune job progress + """ + + estimate_available: bool = False + seconds_remaining: float = 0 + + class FinetuneResponse(BaseModel): """ Fine-tune API response type @@ -393,6 +410,8 @@ class FinetuneResponse(BaseModel): training_file_size: Optional[int] = Field(None, alias="TrainingFileSize") train_on_inputs: Union[StrictBool,
Literal["auto"], None] = "auto" + progress: Union[FinetuneProgress, None] = None + @classmethod def validate_training_type(cls, v: TrainingType) -> TrainingType: if v.type == "Full" or v.type == "": diff --git a/src/together/resources/audio/transcriptions.py b/src/together/resources/audio/transcriptions.py index 7c7cb7c2..03cb8d72 100644 --- a/src/together/resources/audio/transcriptions.py +++ b/src/together/resources/audio/transcriptions.py @@ -47,7 +47,7 @@ def with_streaming_response(self) -> TranscriptionsResourceWithStreamingResponse def create( self, *, - file: FileTypes, + file: Union[FileTypes, str], diarize: bool | Omit = omit, language: str | Omit = omit, max_speakers: int | Omit = omit, @@ -68,7 +68,8 @@ def create( Transcribes audio into text Args: - file: Audio file to transcribe + file: Audio file upload or public HTTP/HTTPS URL. Supported formats .wav, .mp3, .m4a, + .webm, .flac. diarize: Whether to enable speaker diarization. When enabled, you will get the speaker id for each word in the transcription. In the response, in the words array, you @@ -168,7 +169,7 @@ def with_streaming_response(self) -> AsyncTranscriptionsResourceWithStreamingRes async def create( self, *, - file: FileTypes, + file: Union[FileTypes, str], diarize: bool | Omit = omit, language: str | Omit = omit, max_speakers: int | Omit = omit, @@ -189,7 +190,8 @@ async def create( Transcribes audio into text Args: - file: Audio file to transcribe + file: Audio file upload or public HTTP/HTTPS URL. Supported formats .wav, .mp3, .m4a, + .webm, .flac. diarize: Whether to enable speaker diarization. When enabled, you will get the speaker id for each word in the transcription. 
In the response, in the words array, you diff --git a/src/together/resources/audio/translations.py b/src/together/resources/audio/translations.py index 393a5ae2..56136915 100644 --- a/src/together/resources/audio/translations.py +++ b/src/together/resources/audio/translations.py @@ -47,7 +47,7 @@ def with_streaming_response(self) -> TranslationsResourceWithStreamingResponse: def create( self, *, - file: FileTypes, + file: Union[FileTypes, str], language: str | Omit = omit, model: Literal["openai/whisper-large-v3"] | Omit = omit, prompt: str | Omit = omit, @@ -65,7 +65,8 @@ def create( Translates audio into English Args: - file: Audio file to translate + file: Audio file upload or public HTTP/HTTPS URL. Supported formats .wav, .mp3, .m4a, + .webm, .flac. language: Target output language. Optional ISO 639-1 language code. If omitted, language is set to English. @@ -145,7 +146,7 @@ def with_streaming_response(self) -> AsyncTranslationsResourceWithStreamingRespo async def create( self, *, - file: FileTypes, + file: Union[FileTypes, str], language: str | Omit = omit, model: Literal["openai/whisper-large-v3"] | Omit = omit, prompt: str | Omit = omit, @@ -163,7 +164,8 @@ async def create( Translates audio into English Args: - file: Audio file to translate + file: Audio file upload or public HTTP/HTTPS URL. Supported formats .wav, .mp3, .m4a, + .webm, .flac. language: Target output language. Optional ISO 639-1 language code. If omitted, language is set to English. diff --git a/src/together/types/audio/transcription_create_params.py b/src/together/types/audio/transcription_create_params.py index 851ccb00..b28fab6f 100644 --- a/src/together/types/audio/transcription_create_params.py +++ b/src/together/types/audio/transcription_create_params.py @@ -11,8 +11,11 @@ class TranscriptionCreateParams(TypedDict, total=False): - file: Required[FileTypes] - """Audio file to transcribe""" + file: Required[Union[FileTypes, str]] + """Audio file upload or public HTTP/HTTPS URL. 
+ + Supported formats .wav, .mp3, .m4a, .webm, .flac. + """ diarize: bool """Whether to enable speaker diarization. diff --git a/src/together/types/audio/translation_create_params.py b/src/together/types/audio/translation_create_params.py index 088e98a1..5c944f5a 100644 --- a/src/together/types/audio/translation_create_params.py +++ b/src/together/types/audio/translation_create_params.py @@ -11,8 +11,11 @@ class TranslationCreateParams(TypedDict, total=False): - file: Required[FileTypes] - """Audio file to translate""" + file: Required[Union[FileTypes, str]] + """Audio file upload or public HTTP/HTTPS URL. + + Supported formats .wav, .mp3, .m4a, .webm, .flac. + """ language: str """Target output language. diff --git a/src/together/types/fine_tuning_cancel_response.py b/src/together/types/fine_tuning_cancel_response.py index 8b41d200..eb60a63b 100644 --- a/src/together/types/fine_tuning_cancel_response.py +++ b/src/together/types/fine_tuning_cancel_response.py @@ -15,6 +15,7 @@ "LrSchedulerLrSchedulerArgs", "LrSchedulerLrSchedulerArgsLinearLrSchedulerArgs", "LrSchedulerLrSchedulerArgsCosineLrSchedulerArgs", + "Progress", "TrainingMethod", "TrainingMethodTrainingMethodSft", "TrainingMethodTrainingMethodDpo", @@ -50,6 +51,16 @@ class LrScheduler(BaseModel): lr_scheduler_args: Optional[LrSchedulerLrSchedulerArgs] = None +class Progress(BaseModel): + """Progress information for the fine-tuning job""" + + estimate_available: bool + """Whether time estimate is available""" + + seconds_remaining: int + """Estimated time remaining in seconds for the fine-tuning job to next state""" + + class TrainingMethodTrainingMethodSft(BaseModel): method: Literal["sft"] @@ -163,6 +174,9 @@ class FineTuningCancelResponse(BaseModel): owner_address: Optional[str] = None """Owner address information""" + progress: Optional[Progress] = None + """Progress information for the fine-tuning job""" + suffix: Optional[str] = None """Suffix added to the fine-tuned model name""" diff --git 
a/src/together/types/fine_tuning_list_response.py b/src/together/types/fine_tuning_list_response.py index dc9f740b..b4fab462 100644 --- a/src/together/types/fine_tuning_list_response.py +++ b/src/together/types/fine_tuning_list_response.py @@ -16,6 +16,7 @@ "DataLrSchedulerLrSchedulerArgs", "DataLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs", "DataLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs", + "DataProgress", "DataTrainingMethod", "DataTrainingMethodTrainingMethodSft", "DataTrainingMethodTrainingMethodDpo", @@ -51,6 +52,16 @@ class DataLrScheduler(BaseModel): lr_scheduler_args: Optional[DataLrSchedulerLrSchedulerArgs] = None +class DataProgress(BaseModel): + """Progress information for the fine-tuning job""" + + estimate_available: bool + """Whether time estimate is available""" + + seconds_remaining: int + """Estimated time remaining in seconds for the fine-tuning job to next state""" + + class DataTrainingMethodTrainingMethodSft(BaseModel): method: Literal["sft"] @@ -164,6 +175,9 @@ class Data(BaseModel): owner_address: Optional[str] = None """Owner address information""" + progress: Optional[DataProgress] = None + """Progress information for the fine-tuning job""" + suffix: Optional[str] = None """Suffix added to the fine-tuned model name""" diff --git a/src/together/types/finetune_response.py b/src/together/types/finetune_response.py index d03e169f..5e02cee1 100644 --- a/src/together/types/finetune_response.py +++ b/src/together/types/finetune_response.py @@ -1,6 +1,7 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
from typing import List, Union, Optional +from datetime import datetime from typing_extensions import Literal, TypeAlias from pydantic import Field as FieldInfo @@ -14,6 +15,7 @@ "LrSchedulerLrSchedulerArgs", "LrSchedulerLrSchedulerArgsLinearLrSchedulerArgs", "LrSchedulerLrSchedulerArgsCosineLrSchedulerArgs", + "Progress", "TrainingMethod", "TrainingMethodTrainingMethodSft", "TrainingMethodTrainingMethodDpo", @@ -47,6 +49,16 @@ class LrScheduler(BaseModel): lr_scheduler_args: Optional[LrSchedulerLrSchedulerArgs] = None +class Progress(BaseModel): + """Progress information for a fine-tuning job""" + + estimate_available: bool + """Whether time estimate is available""" + + seconds_remaining: int + """Estimated time remaining in seconds for the fine-tuning job to next state""" + + class TrainingMethodTrainingMethodSft(BaseModel): method: Literal["sft"] @@ -110,7 +122,7 @@ class FinetuneResponse(BaseModel): batch_size: Union[int, Literal["max"], None] = None - created_at: Optional[str] = None + created_at: Optional[datetime] = None epochs_completed: Optional[int] = None @@ -146,6 +158,9 @@ class FinetuneResponse(BaseModel): param_count: Optional[int] = None + progress: Optional[Progress] = None + """Progress information for a fine-tuning job""" + queue_depth: Optional[int] = None token_count: Optional[int] = None @@ -164,7 +179,7 @@ class FinetuneResponse(BaseModel): trainingfile_size: Optional[int] = None - updated_at: Optional[str] = None + updated_at: Optional[datetime] = None validation_file: Optional[str] = None diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py new file mode 100644 index 00000000..230fe34f --- /dev/null +++ b/tests/test_cli_utils.py @@ -0,0 +1,414 @@ +from typing import Union +from datetime import datetime, timezone + +import pytest +from zoneinfo import ZoneInfo + +from together.lib.cli.api.utils import generate_progress_bar +from together.lib.types.fine_tuning import ( + FinetuneProgress, + FinetuneResponse, + FinetuneJobStatus, +) 
+ + +def create_finetune_response( + status: FinetuneJobStatus = FinetuneJobStatus.STATUS_RUNNING, + updated_at: datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + progress: Union[FinetuneProgress, None] = None, + job_id: str = "ft-test-123", +) -> FinetuneResponse: + """Helper function to create FinetuneResponse objects for testing. + + Args: + status: The job status. + updated_at: Timezone-aware datetime of the job's last update. + progress: Optional FinetuneProgress object. + job_id: The fine-tune job ID. + + Returns: + A FinetuneResponse object for testing. + """ + return FinetuneResponse( + id=job_id, + progress=progress, + updated_at=updated_at, + status=status, + created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + model_output_name="test_model", + adapter_output_name="test_adapter", + TrainingFileNumLines=0, + TrainingFileSize=0, + ) + + +class TestGenerateProgressBarGeneral: + """General test cases for normal operation.""" + + def test_progress_unavailable_when_none(self): + """Test that progress shows unavailable when progress field is None.""" + current_time = datetime(2024, 1, 1, 12, 0, 10, tzinfo=timezone.utc) + finetune_job = create_finetune_response(progress=None) + + result = generate_progress_bar(finetune_job, current_time, use_rich=True) + + assert result == "Progress: [bold red]unavailable[/bold red]" + + def test_progress_unavailable_when_not_set(self): + """Test that progress shows unavailable when field is not provided.""" + current_time = datetime(2024, 1, 1, 12, 0, 10, tzinfo=timezone.utc) + finetune_job = create_finetune_response() + + result = generate_progress_bar(finetune_job, current_time, use_rich=True) + + assert result == "Progress: [bold red]unavailable[/bold red]" + + def test_progress_bar_at_start(self): + """Test progress bar display when job just started (low percentage).""" + current_time = datetime(2024, 1, 1, 12, 0, 10, tzinfo=timezone.utc) + finetune_job = create_finetune_response( +
progress=FinetuneProgress(estimate_available=True, seconds_remaining=1000.0) + ) + + result = generate_progress_bar(finetune_job, current_time, use_rich=True) + + # 10 seconds elapsed / 1000 seconds remaining = 0.01 ratio = 1% progress + # 0.01 * 40 = 0.4, ceil(0.4) = 1 filled bar + assert ( + result + == "Progress: █░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░ [bold] 1%[/bold] [yellow]16min 30s left[/yellow]" + ) + + def test_progress_bar_at_midpoint(self): + """Test progress bar when elapsed time equals the remaining estimate (100%).""" + current_time = datetime(2024, 1, 1, 12, 1, 0, tzinfo=timezone.utc) + finetune_job = create_finetune_response( + progress=FinetuneProgress(estimate_available=True, seconds_remaining=60.0) + ) + + result = generate_progress_bar(finetune_job, current_time, use_rich=True) + + # 60 seconds elapsed / 60 seconds remaining = 1.0 ratio = 100% progress + # 1.0 * 40 = 40 filled bars + assert ( + result + == "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]" + ) + + def test_progress_bar_near_completion(self): + """Test progress bar is capped at 100% when elapsed time exceeds the estimate.""" + current_time = datetime(2024, 1, 1, 12, 5, 0, tzinfo=timezone.utc) + finetune_job = create_finetune_response( + progress=FinetuneProgress(estimate_available=True, seconds_remaining=30.0) + ) + + result = generate_progress_bar(finetune_job, current_time, use_rich=True) + + # 300 seconds elapsed / 30 seconds remaining = 10.0 ratio = 1000% progress + # 10.0 * 40 = 400, ceil(400) = 400, but width is 40 so all filled + assert ( + result + == "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]" + ) + + def test_progress_bar_contains_rich_formatting(self): + """Test that progress bar includes expected Rich markup formatting.""" + current_time = datetime(2024, 1, 1, 12, 0, 30, tzinfo=timezone.utc) + finetune_job = create_finetune_response( + progress=FinetuneProgress(estimate_available=True, seconds_remaining=60.0) + ) + +
result = generate_progress_bar(finetune_job, current_time, use_rich=True) + + # 30 seconds elapsed / 60 seconds remaining = 0.5 ratio = 50% progress + # 0.5 * 40 = 20 filled bars + assert ( + result + == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ [bold] 50%[/bold] [yellow]30s left[/yellow]" + ) + + +class TestGenerateProgressBarRichFormatting: + """Test cases for use_rich parameter.""" + + def test_rich_formatting_removed_when_use_rich_false(self): + """Test that rich formatting tags are removed when use_rich=False.""" + current_time = datetime(2024, 1, 1, 12, 0, 30, tzinfo=timezone.utc) + finetune_job = create_finetune_response( + progress=FinetuneProgress(estimate_available=True, seconds_remaining=60.0) + ) + + result = generate_progress_bar(finetune_job, current_time, use_rich=False) + + assert ( + result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ 50% 30s left" + ) + + def test_rich_formatting_preserved_when_use_rich_true(self): + """Test that rich formatting tags are preserved when use_rich=True.""" + current_time = datetime(2024, 1, 1, 12, 0, 30, tzinfo=timezone.utc) + finetune_job = create_finetune_response( + progress=FinetuneProgress(estimate_available=True, seconds_remaining=60.0) + ) + + result = generate_progress_bar(finetune_job, current_time, use_rich=True) + + assert ( + result + == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ [bold] 50%[/bold] [yellow]30s left[/yellow]" + ) + + def test_completed_status_formatting_removed(self): + """Test that completed status formatting is removed when use_rich=False.""" + current_time = datetime(2024, 1, 1, 12, 0, 10, tzinfo=timezone.utc) + finetune_job = create_finetune_response( + status=FinetuneJobStatus.STATUS_COMPLETED, progress=None + ) + + result = generate_progress_bar(finetune_job, current_time, use_rich=False) + + assert result == "Progress: completed" + + def test_unavailable_status_formatting_removed(self): + """Test that unavailable status formatting is removed when 
use_rich=False.""" + current_time = datetime(2024, 1, 1, 12, 0, 10, tzinfo=timezone.utc) + finetune_job = create_finetune_response(progress=None) + + result = generate_progress_bar(finetune_job, current_time, use_rich=False) + + assert result == "Progress: unavailable" + + def test_rich_formatting_removed_at_completion(self): + """Test that rich formatting is removed at 100% when use_rich=False.""" + current_time = datetime(2024, 1, 1, 12, 1, 0, tzinfo=timezone.utc) + finetune_job = create_finetune_response( + progress=FinetuneProgress(estimate_available=True, seconds_remaining=60.0) + ) + + result = generate_progress_bar(finetune_job, current_time, use_rich=False) + + assert ( + result == "Progress: ████████████████████████████████████████ 100% N/A left" + ) + + def test_default_behavior_strips_formatting(self): + """Test that rich formatting is removed by default (use_rich not specified).""" + current_time = datetime(2024, 1, 1, 12, 0, 30, tzinfo=timezone.utc) + finetune_job = create_finetune_response( + progress=FinetuneProgress(estimate_available=True, seconds_remaining=60.0) + ) + + result = generate_progress_bar(finetune_job, current_time) + + assert ( + result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ 50% 30s left" + ) + + def test_content_consistency_between_modes(self): + """Test that use_rich=True and use_rich=False have same content, just different formatting.""" + import re + + current_time = datetime(2024, 1, 1, 12, 0, 30, tzinfo=timezone.utc) + finetune_job = create_finetune_response( + progress=FinetuneProgress(estimate_available=True, seconds_remaining=60.0) + ) + + result_with_rich = generate_progress_bar( + finetune_job, current_time, use_rich=True + ) + result_without_rich = generate_progress_bar( + finetune_job, current_time, use_rich=False + ) + + stripped_rich = re.sub(r"\[/?[^\]]+\]", "", result_with_rich) + assert stripped_rich == result_without_rich + + def test_all_rich_tag_types_removed(self): + """Test that all types of rich 
formatting tags are properly removed.""" + current_time = datetime(2024, 1, 1, 12, 0, 10, tzinfo=timezone.utc) + + # Test with completed status (has [bold green] tags) + completed_job = create_finetune_response( + status=FinetuneJobStatus.STATUS_COMPLETED, progress=None + ) + result_completed = generate_progress_bar( + completed_job, current_time, use_rich=False + ) + assert result_completed == "Progress: completed" + + # Test with unavailable status (has [bold red] tags) + unavailable_job = create_finetune_response(progress=None) + result_unavailable = generate_progress_bar( + unavailable_job, current_time, use_rich=False + ) + assert result_unavailable == "Progress: unavailable" + + @pytest.mark.parametrize( + "use_rich,expected_completed,expected_running", + [ + ( + True, + "Progress: [bold green]completed[/bold green]", + "Progress: ███████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░ [bold] 17%[/bold] [yellow]50s left[/yellow]", + ), + ( + False, + "Progress: completed", + "Progress: ███████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░ 17% 50s left", + ), + ], + ) + def test_rich_parameter_with_different_statuses( + self, use_rich: bool, expected_completed: str, expected_running: str + ): + """Test use_rich parameter works correctly with different job statuses.""" + current_time = datetime(2024, 1, 1, 12, 0, 10, tzinfo=timezone.utc) + + # Test completed status + completed_job = create_finetune_response( + status=FinetuneJobStatus.STATUS_COMPLETED, progress=None + ) + result = generate_progress_bar(completed_job, current_time, use_rich=use_rich) + assert result == expected_completed + + # Test running status + running_job = create_finetune_response( + progress=FinetuneProgress(estimate_available=True, seconds_remaining=60.0) + ) + result = generate_progress_bar(running_job, current_time, use_rich=use_rich) + assert result == expected_running + + def test_progress_percentage_1_percent(self): + """Test progress bar at 1% completion.""" + current_time = datetime(2024, 1, 1, 12, 0, 10, 
tzinfo=timezone.utc) + finetune_job = create_finetune_response( + progress=FinetuneProgress(estimate_available=True, seconds_remaining=1000.0) + ) + + result = generate_progress_bar(finetune_job, current_time, use_rich=False) + assert ( + result + == "Progress: █░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░ 1% 16min 30s left" + ) + + def test_progress_percentage_75_percent(self): + """Test progress bar at 75% completion.""" + current_time = datetime(2024, 1, 1, 12, 0, 45, tzinfo=timezone.utc) + finetune_job = create_finetune_response( + progress=FinetuneProgress(estimate_available=True, seconds_remaining=60.0) + ) + + result = generate_progress_bar(finetune_job, current_time, use_rich=False) + assert ( + result == "Progress: ██████████████████████████████░░░░░░░░░░ 75% 15s left" + ) + + +class TestGenerateProgressBarCornerCases: + """Corner cases and edge conditions.""" + + def test_zero_seconds_remaining(self): + """Test handling of zero seconds remaining (potential division by zero).""" + current_time = datetime(2024, 1, 1, 12, 0, 10, tzinfo=timezone.utc) + finetune_job = create_finetune_response( + progress=FinetuneProgress(estimate_available=True, seconds_remaining=0.0) + ) + + result = generate_progress_bar(finetune_job, current_time, use_rich=True) + + assert result == "Progress: [bold red]unavailable[/bold red]" + + def test_very_small_remaining_time(self): + """Test with very small remaining time (< 1 second).""" + current_time = datetime(2024, 1, 1, 12, 0, 5, tzinfo=timezone.utc) + finetune_job = create_finetune_response( + progress=FinetuneProgress(estimate_available=True, seconds_remaining=0.5) + ) + + result = generate_progress_bar(finetune_job, current_time, use_rich=True) + + assert ( + result + == "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]" + ) + + def test_very_large_remaining_time(self): + """Test with very large remaining time (hours).""" + current_time = datetime(2024, 1, 1, 12, 0, 30, 
tzinfo=timezone.utc) + finetune_job = create_finetune_response( + progress=FinetuneProgress( + estimate_available=True, seconds_remaining=36000.0 + ) + ) + + result = generate_progress_bar(finetune_job, current_time, use_rich=True) + + assert ( + result + == "Progress: █░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░ [bold] 0%[/bold] [yellow]9h 59min 30s left[/yellow]" + ) + + def test_job_exceeding_estimate(self): + """Test when elapsed time exceeds original estimate (>100% progress).""" + current_time = datetime(2024, 1, 1, 14, 0, 0, tzinfo=timezone.utc) + finetune_job = create_finetune_response( + progress=FinetuneProgress(estimate_available=True, seconds_remaining=60.0) + ) + + result = generate_progress_bar(finetune_job, current_time, use_rich=True) + + assert ( + result + == "Progress: ████████████████████████████████████████ [bold]100%[/bold] [yellow]N/A left[/yellow]" + ) + + def test_timezone_aware_datetime(self): + """Test with different timezone for updated_at.""" + current_time = datetime(2024, 1, 1, 12, 0, 30, tzinfo=timezone.utc) + finetune_job = create_finetune_response( + updated_at=datetime(2024, 1, 1, 7, 0, 0, tzinfo=ZoneInfo("EST")), # Same as 12:00:00 UTC (EST = UTC-5) + progress=FinetuneProgress(estimate_available=True, seconds_remaining=60.0), + ) + + result = generate_progress_bar(finetune_job, current_time, use_rich=True) + + assert ( + result + == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ [bold] 50%[/bold] [yellow]30s left[/yellow]" + ) + + def test_estimate_unavailable_flag(self): + """Test when estimate_available flag is False.""" + current_time = datetime(2024, 1, 1, 12, 0, 50, tzinfo=timezone.utc) + finetune_job = create_finetune_response( + progress=FinetuneProgress(estimate_available=False, seconds_remaining=100.0) + ) + + result = generate_progress_bar(finetune_job, current_time, use_rich=True) + + assert result == "Progress: [bold red]unavailable[/bold red]" + + def test_negative_elapsed_time_scenario(self): + """Test unusual case 
where current time appears before updated_at.""" + current_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + finetune_job = create_finetune_response( + updated_at=datetime(2024, 1, 1, 12, 0, 30, tzinfo=timezone.utc), # In the "future" + progress=FinetuneProgress(estimate_available=True, seconds_remaining=100.0), + ) + + result = generate_progress_bar(finetune_job, current_time, use_rich=True) + + assert result == "Progress: [bold red]unavailable[/bold red]" + + def test_unicode_progress_bars_preserved(self): + """Test that unicode characters in progress bars are preserved after tag removal.""" + current_time = datetime(2024, 1, 1, 12, 0, 30, tzinfo=timezone.utc) + finetune_job = create_finetune_response( + progress=FinetuneProgress(estimate_available=True, seconds_remaining=60.0) + ) + + result = generate_progress_bar(finetune_job, current_time, use_rich=False) + + assert ( + result == "Progress: ████████████████████░░░░░░░░░░░░░░░░░░░░ 50% 30s left" + ) diff --git a/uv.lock b/uv.lock index d16aaa7d..0143d6bf 100644 --- a/uv.lock +++ b/uv.lock @@ -1963,7 +1963,7 @@ wheels = [ [[package]] name = "together" -version = "2.0.0a8" +version = "2.0.0a10" source = { editable = "." } dependencies = [ { name = "anyio" },