
Commit a7cab74

VLM Fine-tuning support (#201)
* Support VLM finetuning
1 parent 97c74a3 commit a7cab74

File tree

7 files changed: +358 −40 lines changed


src/together/lib/cli/api/fine_tuning.py

Lines changed: 12 additions & 0 deletions
@@ -176,6 +176,12 @@ def fine_tuning(ctx: click.Context) -> None:
     help="Whether to mask the user messages in conversational data or prompts in instruction data. "
     "`auto` will automatically determine whether to mask the inputs based on the data format.",
 )
+@click.option(
+    "--train-vision",
+    type=bool,
+    default=False,
+    help="Whether to train the vision encoder. Only supported for multimodal models.",
+)
 @click.option(
     "--from-checkpoint",
     type=str,
@@ -231,6 +237,7 @@ def create(
     lora_dropout: float | None,
     lora_alpha: float | None,
     lora_trainable_modules: str | None,
+    train_vision: bool,
     suffix: str | None,
     wandb_api_key: str | None,
     wandb_base_url: str | None,
@@ -272,6 +279,7 @@ def create(
         lora_dropout=lora_dropout,
         lora_alpha=lora_alpha,
         lora_trainable_modules=lora_trainable_modules,
+        train_vision=train_vision,
         suffix=suffix,
         wandb_api_key=wandb_api_key,
         wandb_base_url=wandb_base_url,
@@ -363,6 +371,10 @@ def create(
             simpo_gamma=simpo_gamma or 0,
         )
 
+    if model_limits.supports_vision:
+        # Don't show price estimation for multimodal models yet
+        confirm = True
+
     finetune_price_estimation_result = client.fine_tuning.estimate_price(
         training_file=training_file,
         validation_file=validation_file,

src/together/lib/constants.py

Lines changed: 6 additions & 0 deletions
@@ -37,6 +37,12 @@
 # maximum number of GB sized files we support finetuning for
 MAX_FILE_SIZE_GB = 50.1
 
+# Multimodal limits
+MAX_IMAGES_PER_EXAMPLE = 10
+MAX_IMAGE_BYTES = 10 * 1024 * 1024  # 10MB
+# Max length = Header length + base64 factor (4/3) * image bytes
+MAX_BASE64_IMAGE_LENGTH = len("data:image/jpeg;base64,") + 4 * MAX_IMAGE_BYTES // 3
+
 # expected columns for Parquet files
 PARQUET_EXPECTED_COLUMNS = ["input_ids", "attention_mask", "labels"]
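The 4/3 factor comes from base64 encoding, which turns every 3 input bytes into 4 output characters, so a 10MB image expands to roughly 13.3MB of text plus the data-URI header. An illustrative sketch of how these constants could be used to pre-validate an image before embedding it in a training example; check_image() is a hypothetical helper, not part of this commit.

# Illustrative only: check_image() is hypothetical; the constants mirror constants.py.
import base64

MAX_IMAGE_BYTES = 10 * 1024 * 1024  # 10MB
MAX_BASE64_IMAGE_LENGTH = len("data:image/jpeg;base64,") + 4 * MAX_IMAGE_BYTES // 3

def check_image(raw: bytes) -> str:
    """Encode an image as a data URI, enforcing the byte limit above."""
    if len(raw) > MAX_IMAGE_BYTES:
        raise ValueError(f"image is {len(raw)} bytes; the limit is {MAX_IMAGE_BYTES}")
    # Base64 turns every 3 input bytes into 4 output characters, hence the
    # 4/3 expansion baked into MAX_BASE64_IMAGE_LENGTH.
    return "data:image/jpeg;base64," + base64.b64encode(raw).decode("ascii")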

src/together/lib/resources/fine_tuning.py

Lines changed: 10 additions & 0 deletions
@@ -22,6 +22,7 @@
     CosineLRSchedulerArgs,
     LinearLRSchedulerArgs,
     FinetuneTrainingLimits,
+    FinetuneMultimodalParams,
 )
 
 AVAILABLE_TRAINING_METHODS = {
@@ -51,6 +52,7 @@ def create_finetune_request(
     lora_dropout: float | None = 0,
     lora_alpha: float | None = None,
     lora_trainable_modules: str | None = "all-linear",
+    train_vision: bool = False,
     suffix: str | None = None,
     wandb_api_key: str | None = None,
     wandb_base_url: str | None = None,
@@ -207,6 +209,13 @@ def create_finetune_request(
         simpo_gamma=simpo_gamma,
     )
 
+    if model_limits.supports_vision:
+        multimodal_params = FinetuneMultimodalParams(train_vision=train_vision)
+    elif not model_limits.supports_vision and train_vision:
+        raise ValueError(f"Vision encoder training is not supported for the non-multimodal model `{model}`")
+    else:
+        multimodal_params = None
+
     finetune_request = FinetuneRequest(
         model=model,
         training_file=training_file,
@@ -227,6 +236,7 @@ def create_finetune_request(
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         training_method=training_method_cls,  # pyright: ignore[reportPossiblyUnboundVariable]
+        multimodal_params=multimodal_params,
         from_checkpoint=from_checkpoint,
         from_hf_model=from_hf_model,
         hf_model_revision=hf_model_revision,
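A small sketch, not from this commit, of what the new branching means for callers: vision models get a FinetuneMultimodalParams attached to the request, and asking for vision-encoder training on a text-only model fails fast on the client side. The helper below mirrors the logic above; SimpleNamespace stands in for the FinetuneTrainingLimits returned by the API (only the supports_vision field added here is consulted), the model names are placeholders, and the import path is assumed to follow the file layout shown in this diff.

# Sketch mirroring the validation above; not part of the SDK.
from types import SimpleNamespace
from together.lib.types.fine_tuning import FinetuneMultimodalParams

def resolve_multimodal_params(limits, model: str, train_vision: bool):
    """Params for vision-capable models, an early error otherwise."""
    if limits.supports_vision:
        return FinetuneMultimodalParams(train_vision=train_vision)
    if train_vision:
        raise ValueError(
            f"Vision encoder training is not supported for the non-multimodal model `{model}`"
        )
    return None

print(resolve_multimodal_params(SimpleNamespace(supports_vision=True), "vlm-model", True))
try:
    resolve_multimodal_params(SimpleNamespace(supports_vision=False), "text-model", True)
except ValueError as err:
    print(err)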

src/together/lib/types/fine_tuning.py

Lines changed: 15 additions & 0 deletions
@@ -203,6 +203,7 @@ class FinetuneTrainingLimits(BaseModel):
     min_learning_rate: float
     full_training: Optional[FinetuneFullTrainingLimits] = None
     lora_training: Optional[FinetuneLoraTrainingLimits] = None
+    supports_vision: bool = False
 
 
 class LinearLRSchedulerArgs(BaseModel):
@@ -270,6 +271,14 @@ class UnknownLRScheduler(BaseModel):
 ]
 
 
+class FinetuneMultimodalParams(BaseModel):
+    """
+    Multimodal parameters
+    """
+
+    train_vision: bool = False
+
+
 class FinetuneProgress(BaseModel):
     """
     Fine-tune job progress
@@ -305,6 +314,9 @@ class FinetuneResponse(BaseModel):
     from_checkpoint: Optional[str] = None
     """Checkpoint used to continue training"""
 
+    multimodal_params: Optional[FinetuneMultimodalParams] = None
+    """Multimodal parameters"""
+
     from_hf_model: Optional[str] = None
     """Hugging Face Hub repo to start training from"""
 
@@ -469,6 +481,9 @@ class FinetuneRequest(BaseModel):
     training_method: TrainingMethod = Field(default_factory=TrainingMethodSFT)
     # from step
     from_checkpoint: Union[str, None] = None
+    # multimodal parameters
+    multimodal_params: Union[FinetuneMultimodalParams, None] = None
+    # hugging face related fields
     from_hf_model: Union[str, None] = None
     hf_model_revision: Union[str, None] = None
     # hf related fields
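For reference, a tiny sketch of the new Pydantic model on its own, not part of the diff: it assumes the import path follows the file layout above and that the SDK's BaseModel is Pydantic v2 (model_dump / model_dump_json; under Pydantic v1 these would be .dict() / .json()).

# Sketch only: exercising the new FinetuneMultimodalParams model.
from together.lib.types.fine_tuning import FinetuneMultimodalParams

params = FinetuneMultimodalParams()            # defaults to train_vision=False
print(params.model_dump())                     # {'train_vision': False}

params = FinetuneMultimodalParams(train_vision=True)
print(params.model_dump_json())                # {"train_vision": true}

When supports_vision is True in the model's training limits, create_finetune_request attaches this object to the FinetuneRequest as multimodal_params, and the same field comes back on FinetuneResponse.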
