generated from mozilla-ai/Blueprint-template
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* [WIP] add UI config for finetuning in notebook * Add hf reference * Add config file using pydantic and yaml * Lint * Fix yaml issues * Convert TrainingConfig to dict * Fix EOF for yaml * Fix attribute name * Update import paths * Use os.cpu_count() instead of config arg
1 parent
ea5a1ae
commit 77cff00
Showing
6 changed files
with
128 additions
and
98 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import yaml | ||
from pydantic import BaseModel | ||
|
||
|
||
def load_config(config_path: str): | ||
with open(config_path, "r") as file: | ||
config_dict = yaml.safe_load(file) | ||
|
||
return Config(**config_dict) | ||
|
||
|
||
class TrainingConfig(BaseModel): | ||
""" | ||
More info at https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments | ||
""" | ||
|
||
push_to_hub: bool | ||
hub_private_repo: bool | ||
max_steps: int | ||
per_device_train_batch_size: int | ||
gradient_accumulation_steps: int | ||
learning_rate: float | ||
warmup_steps: int | ||
gradient_checkpointing: bool | ||
fp16: bool | ||
eval_strategy: str | ||
per_device_eval_batch_size: int | ||
predict_with_generate: bool | ||
generation_max_length: int | ||
save_steps: int | ||
logging_steps: int | ||
load_best_model_at_end: bool | ||
metric_for_best_model: str | ||
greater_is_better: bool | ||
|
||
|
||
class Config(BaseModel): | ||
""" | ||
Store configuration used for finetuning | ||
Args: | ||
model_id (str): HF model id of a Whisper model used for finetuning | ||
dataset_id (str): HF dataset id of a Common Voice dataset version, ideally from the mozilla-foundation repo | ||
language (str): registered language string that is supported by the Common Voice dataset | ||
repo_name (str | None): used both for local dir and HF, None will create a name based on the model and language id | ||
training_hp (TrainingConfig): store selective hyperparameter values from Seq2SeqTrainingArguments | ||
""" | ||
|
||
model_id: str | ||
dataset_id: str | ||
language: str | ||
repo_name: str | None | ||
training_hp: TrainingConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
model_id: openai/whisper-tiny | ||
dataset_id: mozilla-foundation/common_voice_17_0 | ||
language: Greek | ||
repo_name: None | ||
|
||
training_hp: | ||
push_to_hub: False | ||
hub_private_repo: True | ||
max_steps: 1 | ||
per_device_train_batch_size: 64 | ||
gradient_accumulation_steps: 1 | ||
learning_rate: 1e-5 | ||
warmup_steps: 50 | ||
gradient_checkpointing: True | ||
fp16: True | ||
eval_strategy: steps | ||
per_device_eval_batch_size: 8 | ||
predict_with_generate: True | ||
generation_max_length: 225 | ||
save_steps: 250 | ||
logging_steps: 25 | ||
load_best_model_at_end: True | ||
metric_for_best_model: wer | ||
greater_is_better: False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters