LR multiplier for ESM2 finetuning layers #609

Merged (3 commits) on Jan 21, 2025
13 changes: 12 additions & 1 deletion docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb
@@ -435,7 +435,14 @@
"```bash\n",
"finetune_esm2 --help \n",
"```\n",
"For a detailed description of training loop and the arguments please refer to the [ESM-2 Pretraining](./pretrain.md) tutorial."
"\n",
"For a detailed description of training loop and the arguments please refer to the [ESM-2 Pretraining](./pretrain.md) tutorial.\n",
"\n",
"#### Scaled LR for fine-tune head parameters \n",
"We can assign a different LR for specific layers (e.g. task head) during fine-tuning by making it possible to specify the name of the target layer as well as the LR multiplier.\n",
"\n",
"- `--lr-multiplier`: is a float that scales `--lr`\n",
"- `--sclae-lr-layer`: is the name of the layers for which we scale the LR"
]
},
{
@@ -522,6 +529,8 @@
" --val-check-interval 10 \\\n",
" --log-every-n-steps 10 \\\n",
" --lr 5e-3 \\\n",
" --lr-multiplier 1e2 \\\n",
" --scale-lr-layer \"regression_head\" \\\n",
" --result-dir {work_dir} \\\n",
" --micro-batch-size 2 \\\n",
" --num-gpus 1 \\\n",
@@ -689,6 +698,8 @@
" --val-check-interval 10 \\\n",
" --log-every-n-steps 10 \\\n",
" --lr 5e-3 \\\n",
" --lr-multiplier 1e2 \\\n",
" --scale-lr-layer \"classification_head\" \\\n",
" --result-dir {work_dir} \\\n",
" --micro-batch-size 2 \\\n",
" --num-gpus 1 \\\n",
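The notebook cell above introduces `--scale-lr-layer` and `--lr-multiplier` without showing their effect on the optimizer. The sketch below is a minimal, self-contained illustration in plain PyTorch of what the two flags amount to conceptually; it is not the `MegatronOptimizerModule` code path used by `train_model` (which sets `scale_lr_cond` and `lr_mult` instead), and `TinyModel` and `build_param_groups` are hypothetical names used only for this example.

```python
from typing import Optional

import torch


class TinyModel(torch.nn.Module):
    """Hypothetical stand-in for an ESM-2 trunk plus a fine-tuning head."""

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(8, 8)          # stands in for the pretrained trunk
        self.regression_head = torch.nn.Linear(8, 1)  # stands in for the task head


def build_param_groups(model, lr: float, scale_lr_layer: Optional[str], lr_multiplier: float):
    """Place parameters whose name contains `scale_lr_layer` into a group with a scaled LR."""
    base, scaled = [], []
    for name, param in model.named_parameters():
        (scaled if scale_lr_layer and scale_lr_layer in name else base).append(param)
    groups = [{"params": base, "lr": lr}]
    if scaled:
        groups.append({"params": scaled, "lr": lr * lr_multiplier})
    return groups


# With --lr 5e-3 and --lr-multiplier 1e2, the head trains at 0.5 while the trunk stays at 5e-3.
optimizer = torch.optim.Adam(
    build_param_groups(TinyModel(), lr=5e-3, scale_lr_layer="regression_head", lr_multiplier=1e2)
)
print([group["lr"] for group in optimizer.param_groups])  # [0.005, 0.5]
```

Running it prints `[0.005, 0.5]`: the `regression_head` parameters train at `lr * lr_multiplier` while the pretrained trunk keeps the base learning rate.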
@@ -76,6 +76,8 @@ def train_model(
experiment_name: str,
resume_if_exists: bool,
precision: PrecisionTypes,
scale_lr_layer: Optional[str] = None,
lr_multiplier: float = 1.0,
wandb_entity: Optional[str] = None,
wandb_project: Optional[str] = None,
wandb_offline: bool = False,
@@ -125,6 +127,8 @@
result_dir that stores the logs and checkpoints.
resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]
precision (PrecisionTypes): Precision type for training (e.g., float16, float32)
scale_lr_layer (Optional[str]): name of the layers for which the LR is scaled by lr_multiplier (matched as a substring of parameter names)
lr_multiplier (float): LR multiplier applied to parameters whose name contains scale_lr_layer
wandb_entity (Optional[str]): The team posting this run (default: your username or your default team)
wandb_project (Optional[str]): The name of the project to which this run will belong
wandb_offline (bool): Run offline (data can be streamed later to wandb servers).
@@ -271,8 +275,13 @@ def train_model(
weight_decay=0.01,
adam_beta1=0.9,
adam_beta2=0.98,
)
),
)
# fiddle cannot serialize lambda functions, so the scale_lr_cond predicate is attached to
# the optimizer after construction instead of being part of the serialized optimizer configuration
if scale_lr_layer:
optimizer.scale_lr_cond = lambda name, param: scale_lr_layer in name
optimizer.lr_mult = lr_multiplier

module = biobert_lightning_module(config=config, tokenizer=tokenizer, optimizer=optimizer)

@@ -342,6 +351,8 @@ def finetune_esm2_entrypoint():
tensor_model_parallel_size=args.tensor_model_parallel_size,
accumulate_grad_batches=args.accumulate_grad_batches,
precision=args.precision,
scale_lr_layer=args.scale_lr_layer,
lr_multiplier=args.lr_multiplier,
experiment_name=args.experiment_name,
resume_if_exists=args.resume_if_exists,
restore_from_checkpoint_path=args.restore_from_checkpoint_path,
@@ -394,6 +405,20 @@ def get_parser():
default=4e-4,
help="Learning rate for training. Default is 4e-4",
)
parser.add_argument(
"--scale-lr-layer",
type=str,
required=False,
default=None,
help="Layer name for which we scale the lr by lr-multiplier",
)
parser.add_argument(
"--lr-multiplier",
type=float,
required=False,
default=1.0,
help="Learning rate multiplier for layers with scale-lr-layer in their name",
)
parser.add_argument(
"--create-tensorboard-logger", action="store_true", default=False, help="Create a tensorboard logger."
)
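The comment in `train_model` notes that fiddle cannot serialize lambdas, which is why `scale_lr_cond` is attached to the optimizer after construction rather than passed through the config. As a rough analogy only (pickle is not fiddle, and the mechanisms differ), the snippet below shows the general problem with inline lambdas versus named module-level functions:

```python
# Analogy only: pickle, like many config serializers, handles a named module-level
# function (serialized by reference) but refuses an inline lambda.
import pickle


def scale_lr_cond(name: str, param) -> bool:
    """Named predicate: parameters of the fine-tuning head get the scaled LR."""
    return "regression_head" in name


pickle.dumps(scale_lr_cond)  # fine: serialized by module and qualified name

try:
    pickle.dumps(lambda name, param: "regression_head" in name)
except Exception as err:  # pickle raises PicklingError for inline lambdas
    print(f"cannot serialize lambda: {err}")
```

Attaching the lambda after the optimizer config is built keeps it out of the serialized configuration entirely, which is the workaround the PR takes.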
@@ -69,6 +69,8 @@ def test_esm2_finetune_token_classifier(
log_every_n_steps=n_steps_train // 2,
num_dataset_workers=10,
lr=1e-5,
scale_lr_layer="classification_head",
lr_multiplier=1e2,
micro_batch_size=4,
accumulate_grad_batches=1,
resume_if_exists=False,
@@ -114,6 +116,8 @@ def test_esm2_finetune_regressor(
log_every_n_steps=n_steps_train // 2,
num_dataset_workers=10,
lr=1e-5,
scale_lr_layer="regression_head",
lr_multiplier=1e2,
micro_batch_size=4,
accumulate_grad_batches=1,
resume_if_exists=False,