Allow finetuning ESM2 with [un]frozen encoder #620

Merged Jan 22, 2025 · 5 commits
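For context, the sketch below shows how the new option is wired: an --encoder-frozen store_true flag is parsed, passed to train_model(encoder_frozen=...), and forwarded into the model config. FineTuneConfig is a hypothetical stand-in for ESM2FineTuneSeqConfig / ESM2FineTuneTokenConfig, not the real BioNeMo classes.

# Hypothetical sketch, not the real finetune_esm2 module: mirrors the path of the
# --encoder-frozen flag added in this PR (argparse -> train_model -> config_class).
import argparse
from dataclasses import dataclass


@dataclass
class FineTuneConfig:  # stand-in for ESM2FineTuneSeqConfig / ESM2FineTuneTokenConfig
    encoder_frozen: bool = False  # when True, encoder parameters are excluded from training


def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # Same semantics as the new CLI flag: absent -> False, present -> True.
    parser.add_argument(
        "--encoder-frozen",
        action="store_true",
        default=False,
        help="Freeze the encoder parameters",
    )
    return parser


if __name__ == "__main__":
    args = get_parser().parse_args(["--encoder-frozen"])
    config = FineTuneConfig(encoder_frozen=args.encoder_frozen)
    print(config)  # FineTuneConfig(encoder_frozen=True)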
2 changes: 2 additions & 0 deletions docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb
@@ -528,6 +528,7 @@
" --num-gpus 1 \\\n",
" --val-check-interval 10 \\\n",
" --log-every-n-steps 10 \\\n",
" --encoder-frozen \\\n",
" --lr 5e-3 \\\n",
" --lr-multiplier 1e2 \\\n",
" --scale-lr-layer \"regression_head\" \\\n",
@@ -697,6 +698,7 @@
" --num-gpus 1 \\\n",
" --val-check-interval 10 \\\n",
" --log-every-n-steps 10 \\\n",
" --encoder-frozen \\\n",
" --lr 5e-3 \\\n",
" --lr-multiplier 1e2 \\\n",
" --scale-lr-layer \"classification_head\" \\\n",
Changes to the ESM2 fine-tuning script (train_model, finetune_esm2_entrypoint, get_parser):
@@ -76,6 +76,7 @@ def train_model(
experiment_name: str,
resume_if_exists: bool,
precision: PrecisionTypes,
encoder_frozen: bool = False,
scale_lr_layer: Optional[str] = None,
lr_multiplier: float = 1.0,
wandb_entity: Optional[str] = None,
@@ -100,7 +101,7 @@ def train_model(
dataset_class: Type[InMemoryProteinDataset] = InMemorySingleValueDataset,
config_class: Type[BioBertConfig] = ESM2FineTuneSeqConfig,
metric_tracker: Callback | None = None,
overlap_grad_reduce: bool = True,
overlap_grad_reduce: bool = False, # Default to False to avoid communication issue in gradient synchronization step
overlap_param_gather: bool = True,
average_in_collective: bool = True,
grad_reduce_in_fp32: bool = False,
@@ -127,6 +128,7 @@ def train_model(
result_dir that stores the logs and checkpoints.
resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]
precision (PrecisionTypes): Precision type for training (e.g., float16, float32)
encoder_frozen (bool): Freeze the encoder parameters. Default is False.
scale_lr_layer (Optional[str]): layer names for which the lr is scaled by lr_multiplier
lr_multiplier (float): lr multiplier for parameters in scale_lr_layer
wandb_entity (Optional[str]): The team posting this run (default: your username or your default team)
@@ -258,6 +260,7 @@ def train_model(
)
# Configure the model
config = config_class(
encoder_frozen=encoder_frozen,
params_dtype=get_autocast_dtype(precision),
pipeline_dtype=get_autocast_dtype(precision),
autocast_dtype=get_autocast_dtype(precision), # setting this speeds things up a lot
@@ -351,6 +354,7 @@ def finetune_esm2_entrypoint():
tensor_model_parallel_size=args.tensor_model_parallel_size,
accumulate_grad_batches=args.accumulate_grad_batches,
precision=args.precision,
encoder_frozen=args.encoder_frozen,
scale_lr_layer=args.scale_lr_layer,
lr_multiplier=args.lr_multiplier,
experiment_name=args.experiment_name,
@@ -365,7 +369,7 @@ def finetune_esm2_entrypoint():
nsys_ranks=args.nsys_ranks,
dataset_class=args.dataset_class,
config_class=args.config_class,
overlap_grad_reduce=not args.no_overlap_grad_reduce,
overlap_grad_reduce=args.overlap_grad_reduce,
overlap_param_gather=not args.no_overlap_param_gather,
average_in_collective=not args.no_average_in_collective,
grad_reduce_in_fp32=args.grad_reduce_in_fp32,
@@ -398,6 +402,12 @@ def get_parser():
default="bf16-mixed",
help="Precision type to use for training.",
)
parser.add_argument(
"--encoder-frozen",
action="store_true",
default=False,
help="Freeze the encoder parameters",
)
parser.add_argument(
"--lr",
type=float,
@@ -596,7 +606,7 @@ def get_parser():
)
# DDP config
parser.add_argument(
"--no-overlap-grad-reduce",
"--overlap-grad-reduce",
action="store_true",
default=False,
)
Changes to the ESM2 fine-tuning tests:
@@ -42,9 +42,11 @@ def data_to_csv(data, tmp_path):


@pytest.mark.needs_gpu
@pytest.mark.parametrize("encoder_frozen", [True, False])
def test_esm2_finetune_token_classifier(
tmp_path,
dummy_data_per_token_classification_ft,
encoder_frozen,
n_steps_train: int = 50,
seed: int = 42,
):
@@ -71,6 +73,7 @@ def test_esm2_finetune_token_classifier(
accumulate_grad_batches=1,
resume_if_exists=False,
precision="bf16-mixed",
encoder_frozen=encoder_frozen,
dataset_class=InMemoryPerTokenValueDataset,
config_class=ESM2FineTuneTokenConfig,
metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]),
@@ -85,13 +88,17 @@ def test_esm2_finetune_token_classifier(
encoder_requires_grad = [
p.requires_grad for name, p in trainer.model.named_parameters() if "classification_head" not in name
]
assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning"
assert (
not all(encoder_requires_grad) == encoder_frozen
), f"Conflict in param requires_grad when encoder_frozen={encoder_frozen}"


@pytest.mark.needs_gpu
@pytest.mark.parametrize("encoder_frozen", [True, False])
def test_esm2_finetune_regressor(
tmp_path,
dummy_data_single_value_regression_ft,
encoder_frozen,
n_steps_train: int = 50,
seed: int = 42,
):
@@ -118,6 +125,7 @@ def test_esm2_finetune_regressor(
accumulate_grad_batches=1,
resume_if_exists=False,
precision="bf16-mixed",
encoder_frozen=encoder_frozen,
dataset_class=InMemorySingleValueDataset,
config_class=ESM2FineTuneSeqConfig,
metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]),
@@ -132,7 +140,9 @@ def test_esm2_finetune_regressor(
encoder_requires_grad = [
p.requires_grad for name, p in trainer.model.named_parameters() if "regression_head" not in name
]
assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning"
assert (
not all(encoder_requires_grad) == encoder_frozen
), f"Conflict in param requires_grad when encoder_frozen={encoder_frozen}"


@pytest.fixture
@@ -258,14 +268,19 @@ def test_get_parser():
"--nsys-ranks",
"0",
"1",
"--no-overlap-grad-reduce",
"--overlap-grad-reduce",
"--no-overlap-param-gather",
"--no-average-in-collective",
"--grad-reduce-in-fp32",
"--dataset-class",
"InMemoryPerTokenValueDataset",
"--config-class",
"ESM2FineTuneTokenConfig",
"--encoder-frozen",
"--lr-multiplier",
"1e2",
"--scale-lr-layer",
"dummy_layer",
]
)

@@ -307,9 +322,12 @@ def test_get_parser():
assert args.nsys_start_step == 10
assert args.nsys_end_step == 50
assert args.nsys_ranks == [0, 1]
assert args.no_overlap_grad_reduce is True
assert args.overlap_grad_reduce is True
assert args.no_overlap_param_gather is True
assert args.no_average_in_collective is True
assert args.grad_reduce_in_fp32 is True
assert args.dataset_class == InMemoryPerTokenValueDataset
assert args.config_class == ESM2FineTuneTokenConfig
assert args.encoder_frozen is True
assert args.lr_multiplier == 100
assert args.scale_lr_layer == "dummy_layer"
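The updated assertions tie the encoder's requires_grad state to the encoder_frozen parametrization. The toy example below illustrates the same check outside the test suite; ToyModel and encoder_is_frozen are hypothetical stand-ins, not the ESM2 fine-tune module or its test helpers.

# Hypothetical illustration of the freezing check used in the tests above:
# with a frozen encoder, every parameter outside the task head should have
# requires_grad == False. ToyModel is a stand-in for the ESM2 fine-tune module.
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self, encoder_frozen: bool):
        super().__init__()
        self.encoder = nn.Linear(8, 8)          # stand-in for the ESM2 encoder
        self.regression_head = nn.Linear(8, 1)  # stand-in for the task head
        if encoder_frozen:
            for p in self.encoder.parameters():
                p.requires_grad = False


def encoder_is_frozen(model: nn.Module, head_name: str = "regression_head") -> bool:
    encoder_requires_grad = [
        p.requires_grad for name, p in model.named_parameters() if head_name not in name
    ]
    return not any(encoder_requires_grad)


assert encoder_is_frozen(ToyModel(encoder_frozen=True))
assert not encoder_is_frozen(ToyModel(encoder_frozen=False))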