Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/together/lib/cli/api/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
import click
from rich import print as rprint
from tabulate import tabulate
from rich.json import JSON
from click.core import ParameterSource # type: ignore[attr-defined]

from together import Together
from together.types import fine_tuning_estimate_price_params as pe_params
from together._types import NOT_GIVEN, NotGiven
from together.lib.utils import log_warn
from together.lib.utils.tools import format_timestamp, finetune_price_to_dollars
from together.lib.cli.api.utils import INT_WITH_MAX, BOOL_WITH_AUTO
from together.lib.cli.api.utils import INT_WITH_MAX, BOOL_WITH_AUTO, generate_progress_bar
from together.lib.resources.files import DownloadManager
from together.lib.utils.serializer import datetime_serializer
from together.types.finetune_response import TrainingTypeFullTrainingType, TrainingTypeLoRaTrainingType
Expand Down Expand Up @@ -361,7 +362,7 @@ def create(
rpo_alpha=rpo_alpha or 0,
simpo_gamma=simpo_gamma or 0,
)

finetune_price_estimation_result = client.fine_tuning.estimate_price(
training_file=training_file,
validation_file=validation_file,
Expand Down Expand Up @@ -425,6 +426,9 @@ def list(ctx: click.Context) -> None:
"Price": f"""${
finetune_price_to_dollars(float(str(i.total_price)))
}""", # convert to string for mypy typing
"Progress": generate_progress_bar(
i, datetime.now().astimezone(), use_rich=False
),
}
)
table = tabulate(display_list, headers="keys", tablefmt="grid", showindex=True)
Expand All @@ -444,7 +448,12 @@ def retrieve(ctx: click.Context, fine_tune_id: str) -> None:
# remove events from response for cleaner output
response.events = None

click.echo(json.dumps(response.model_dump(exclude_none=True), indent=4))
rprint(JSON.from_data(response.model_json_schema()))
progress_text = generate_progress_bar(
response, datetime.now().astimezone(), use_rich=True
)
prefix = f"Status: [bold]{response.status}[/bold],"
rprint(f"{prefix} {progress_text}")


@fine_tuning.command()
Expand Down
105 changes: 97 additions & 8 deletions src/together/lib/cli/api/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
from __future__ import annotations

from typing import Literal
import re
import math
from typing import List, Union, Literal
from gettext import gettext as _
from typing_extensions import override
from datetime import datetime

import click

from together.lib.types.fine_tuning import COMPLETED_STATUSES, FinetuneResponse
from together.types.finetune_response import FinetuneResponse as _FinetuneResponse
from together.types.fine_tuning_list_response import Data

_PROGRESS_BAR_WIDTH = 40


class AutoIntParamType(click.ParamType):
name = "integer_or_max"
_number_class = int

@override
def convert(
def convert( # pyright: ignore[reportImplicitOverride]
self, value: str, param: click.Parameter | None, ctx: click.Context | None
) -> int | Literal["max"] | None:
if value == "max":
Expand All @@ -21,7 +28,9 @@ def convert(
return int(value)
except ValueError:
self.fail(
_("{value!r} is not a valid {number_type}.").format(value=value, number_type=self.name),
_("{value!r} is not a valid {number_type}.").format(
value=value, number_type=self.name
),
param,
ctx,
)
Expand All @@ -30,8 +39,7 @@ def convert(
class BooleanWithAutoParamType(click.ParamType):
name = "boolean_or_auto"

@override
def convert(
def convert( # pyright: ignore[reportImplicitOverride]
self, value: str, param: click.Parameter | None, ctx: click.Context | None
) -> bool | Literal["auto"] | None:
if value == "auto":
Expand All @@ -40,11 +48,92 @@ def convert(
return bool(value)
except ValueError:
self.fail(
_("{value!r} is not a valid {type}.").format(value=value, type=self.name),
_("{value!r} is not a valid {type}.").format(
value=value, type=self.name
),
param,
ctx,
)


INT_WITH_MAX = AutoIntParamType()
BOOL_WITH_AUTO = BooleanWithAutoParamType()


def _human_readable_time(timedelta: float) -> str:
"""Convert a timedelta to a compact human-readble string
Examples:
00:00:10 -> 10s
01:23:45 -> 1h 23min 45s
1 Month 23 days 04:56:07 -> 1month 23d 4h 56min 7s
Args:
timedelta (float): The timedelta in seconds to convert.
Returns:
A string representing the timedelta in a human-readable format.
"""
units = [
(30 * 24 * 60 * 60, "month"), # 30 days
(24 * 60 * 60, "d"),
(60 * 60, "h"),
(60, "min"),
(1, "s"),
]

total_seconds = int(timedelta)
parts: List[str] = []

for unit_seconds, unit_name in units:
if total_seconds >= unit_seconds:
value = total_seconds // unit_seconds
total_seconds %= unit_seconds
parts.append(f"{value}{unit_name}")

return " ".join(parts) if parts else "0s"


def generate_progress_bar(
finetune_job: Union[Data, FinetuneResponse, _FinetuneResponse], current_time: datetime, use_rich: bool = False
) -> str:
"""Generate a progress bar for a finetune job.
Args:
finetune_job: The finetune job to generate a progress bar for.
current_time: The current time.
use_rich: Whether to use rich formatting.
Returns:
A string representing the progress bar.
"""
progress = "Progress: [bold red]unavailable[/bold red]"
if finetune_job.status in COMPLETED_STATUSES:
progress = "Progress: [bold green]completed[/bold green]"
elif finetune_job.updated_at is not None:
update_at = finetune_job.updated_at.astimezone()

if finetune_job.progress is not None:
if current_time < update_at:
return progress

if not finetune_job.progress.estimate_available:
return progress

if finetune_job.progress.seconds_remaining <= 0:
return progress

elapsed_time = (current_time - update_at).total_seconds()
ratio_filled = min(
elapsed_time / finetune_job.progress.seconds_remaining, 1.0
)
percentage = ratio_filled * 100
filled = math.ceil(ratio_filled * _PROGRESS_BAR_WIDTH)
bar = "█" * filled + "░" * (_PROGRESS_BAR_WIDTH - filled)
time_left = "N/A"
if finetune_job.progress.seconds_remaining > elapsed_time:
time_left = _human_readable_time(
finetune_job.progress.seconds_remaining - elapsed_time
)
time_text = f"{time_left} left"
progress = f"Progress: {bar} [bold]{percentage:>3.0f}%[/bold] [yellow]{time_text}[/yellow]"

if use_rich:
return progress

return re.sub(r"\[/?[^\]]+\]", "", progress)
19 changes: 19 additions & 0 deletions src/together/lib/types/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ class FinetuneJobStatus(str, Enum):
STATUS_COMPLETED = "completed"


COMPLETED_STATUSES = [
FinetuneJobStatus.STATUS_ERROR,
FinetuneJobStatus.STATUS_USER_ERROR,
FinetuneJobStatus.STATUS_COMPLETED,
FinetuneJobStatus.STATUS_CANCELLED,
]


class FinetuneEventType(str, Enum):
"""
Fine-tune job event types
Expand Down Expand Up @@ -260,6 +268,15 @@ class UnknownLRScheduler(BaseModel):
]


class FinetuneProgress(BaseModel):
"""
Fine-tune job progress
"""

estimate_available: bool = False
seconds_remaining: float = 0


class FinetuneResponse(BaseModel):
"""
Fine-tune API response type
Expand Down Expand Up @@ -393,6 +410,8 @@ class FinetuneResponse(BaseModel):
training_file_size: Optional[int] = Field(None, alias="TrainingFileSize")
train_on_inputs: Union[StrictBool, Literal["auto"], None] = "auto"

progress: Union[FinetuneProgress, None] = None

@classmethod
def validate_training_type(cls, v: TrainingType) -> TrainingType:
if v.type == "Full" or v.type == "":
Expand Down
Loading