diff --git a/docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb b/docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb
index 5a85ea222..025b33123 100644
--- a/docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb
+++ b/docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb
@@ -528,6 +528,7 @@
     "    --num-gpus 1 \\\n",
     "    --val-check-interval 10 \\\n",
     "    --log-every-n-steps 10 \\\n",
+    "    --encoder-frozen \\\n",
     "    --lr 5e-3 \\\n",
     "    --lr-multiplier 1e2 \\\n",
     "    --scale-lr-layer \"regression_head\" \\\n",
@@ -697,6 +698,7 @@
     "    --num-gpus 1 \\\n",
     "    --val-check-interval 10 \\\n",
     "    --log-every-n-steps 10 \\\n",
+    "    --encoder-frozen \\\n",
     "    --lr 5e-3 \\\n",
     "    --lr-multiplier 1e2 \\\n",
     "    --scale-lr-layer \"classification_head\" \\\n",
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py
index 5d7cda3a7..93c55a5f3 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py
@@ -76,6 +76,7 @@ def train_model(
     experiment_name: str,
     resume_if_exists: bool,
     precision: PrecisionTypes,
+    encoder_frozen: bool = False,
     scale_lr_layer: Optional[str] = None,
     lr_multiplier: float = 1.0,
     wandb_entity: Optional[str] = None,
@@ -100,7 +101,7 @@ def train_model(
     dataset_class: Type[InMemoryProteinDataset] = InMemorySingleValueDataset,
     config_class: Type[BioBertConfig] = ESM2FineTuneSeqConfig,
     metric_tracker: Callback | None = None,
-    overlap_grad_reduce: bool = True,
+    overlap_grad_reduce: bool = False,  # Default to False to avoid a communication issue in the gradient synchronization step
     overlap_param_gather: bool = True,
     average_in_collective: bool = True,
     grad_reduce_in_fp32: bool = False,
@@ -127,6 +128,7 @@ def train_model(
             result_dir that stores the logs and checkpoints.
         resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]
         precision (PrecisionTypes): Precision type for training (e.g., float16, float32)
+        encoder_frozen (bool): Freeze the encoder parameters. Default is False.
        scale_lr_layer (Optional[str]): layer names for which the lr is scaled by lr_multiplier
        lr_multiplier (float): lr multiplier for parameters in scale_lr_layer
        wandb_entity (Optional[str]): The team posting this run (default: your username or your default team)
@@ -258,6 +260,7 @@ def train_model(
     )
     # Configure the model
     config = config_class(
+        encoder_frozen=encoder_frozen,
         params_dtype=get_autocast_dtype(precision),
         pipeline_dtype=get_autocast_dtype(precision),
         autocast_dtype=get_autocast_dtype(precision),  # setting this speeds things up a lot
@@ -351,6 +354,7 @@ def finetune_esm2_entrypoint():
         tensor_model_parallel_size=args.tensor_model_parallel_size,
         accumulate_grad_batches=args.accumulate_grad_batches,
         precision=args.precision,
+        encoder_frozen=args.encoder_frozen,
         scale_lr_layer=args.scale_lr_layer,
         lr_multiplier=args.lr_multiplier,
         experiment_name=args.experiment_name,
@@ -365,7 +369,7 @@ def finetune_esm2_entrypoint():
         nsys_ranks=args.nsys_ranks,
         dataset_class=args.dataset_class,
         config_class=args.config_class,
-        overlap_grad_reduce=not args.no_overlap_grad_reduce,
+        overlap_grad_reduce=args.overlap_grad_reduce,
         overlap_param_gather=not args.no_overlap_param_gather,
         average_in_collective=not args.no_average_in_collective,
         grad_reduce_in_fp32=args.grad_reduce_in_fp32,
@@ -398,6 +402,12 @@ def get_parser():
         default="bf16-mixed",
         help="Precision type to use for training.",
     )
+    parser.add_argument(
+        "--encoder-frozen",
+        action="store_true",
+        default=False,
+        help="Freeze the encoder parameters",
+    )
     parser.add_argument(
         "--lr",
         type=float,
@@ -596,7 +606,7 @@ def get_parser():
     )
     # DDP config
     parser.add_argument(
-        "--no-overlap-grad-reduce",
+        "--overlap-grad-reduce",
         action="store_true",
         default=False,
     )
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_finetune_esm2.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_finetune_esm2.py
index 631064f32..9d0075bb1 100644
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_finetune_esm2.py
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_finetune_esm2.py
@@ -42,9 +42,11 @@ def data_to_csv(data, tmp_path):


 @pytest.mark.needs_gpu
+@pytest.mark.parametrize("encoder_frozen", [True, False])
 def test_esm2_finetune_token_classifier(
     tmp_path,
     dummy_data_per_token_classification_ft,
+    encoder_frozen,
     n_steps_train: int = 50,
     seed: int = 42,
 ):
@@ -71,6 +73,7 @@ def test_esm2_finetune_token_classifier(
         accumulate_grad_batches=1,
         resume_if_exists=False,
         precision="bf16-mixed",
+        encoder_frozen=encoder_frozen,
         dataset_class=InMemoryPerTokenValueDataset,
         config_class=ESM2FineTuneTokenConfig,
         metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]),
@@ -85,13 +88,17 @@ def test_esm2_finetune_token_classifier(
     encoder_requires_grad = [
         p.requires_grad for name, p in trainer.model.named_parameters() if "classification_head" not in name
     ]
-    assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning"
+    assert (
+        not all(encoder_requires_grad) == encoder_frozen
+    ), f"Conflict in param requires_grad when encoder_frozen={encoder_frozen}"


 @pytest.mark.needs_gpu
+@pytest.mark.parametrize("encoder_frozen", [True, False])
 def test_esm2_finetune_regressor(
     tmp_path,
     dummy_data_single_value_regression_ft,
+    encoder_frozen,
     n_steps_train: int = 50,
     seed: int = 42,
 ):
@@ -118,6 +125,7 @@ def test_esm2_finetune_regressor(
         accumulate_grad_batches=1,
         resume_if_exists=False,
         precision="bf16-mixed",
+        encoder_frozen=encoder_frozen,
         dataset_class=InMemorySingleValueDataset,
         config_class=ESM2FineTuneSeqConfig,
         metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]),
@@ -132,7 +140,9 @@ def test_esm2_finetune_regressor(
     encoder_requires_grad = [
         p.requires_grad for name, p in trainer.model.named_parameters() if "regression_head" not in name
     ]
-    assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning"
+    assert (
+        not all(encoder_requires_grad) == encoder_frozen
+    ), f"Conflict in param requires_grad when encoder_frozen={encoder_frozen}"


 @pytest.fixture
@@ -258,7 +268,7 @@ def test_get_parser():
             "--nsys-ranks",
             "0",
             "1",
-            "--no-overlap-grad-reduce",
+            "--overlap-grad-reduce",
             "--no-overlap-param-gather",
             "--no-average-in-collective",
             "--grad-reduce-in-fp32",
@@ -266,6 +276,11 @@ def test_get_parser():
             "--dataset-class",
             "InMemoryPerTokenValueDataset",
             "--config-class",
             "ESM2FineTuneTokenConfig",
+            "--encoder-frozen",
+            "--lr-multiplier",
+            "1e2",
+            "--scale-lr-layer",
+            "dummy_layer",
         ]
     )
@@ -307,9 +322,12 @@ def test_get_parser():
     assert args.nsys_start_step == 10
     assert args.nsys_end_step == 50
     assert args.nsys_ranks == [0, 1]
-    assert args.no_overlap_grad_reduce is True
+    assert args.overlap_grad_reduce is True
     assert args.no_overlap_param_gather is True
     assert args.no_average_in_collective is True
     assert args.grad_reduce_in_fp32 is True
     assert args.dataset_class == InMemoryPerTokenValueDataset
     assert args.config_class == ESM2FineTuneTokenConfig
+    assert args.encoder_frozen is True
+    assert args.lr_multiplier == 100
+    assert args.scale_lr_layer == "dummy_layer"
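For reference, the behavior the new `encoder_frozen` tests assert can be sketched outside BioNeMo with plain PyTorch: freezing the encoder means setting `requires_grad = False` on every parameter outside the task head, which is exactly what the `encoder_requires_grad` check above verifies for both parametrized values. This is an illustration only, not BioNeMo code; the toy `encoder` and `regression_head` modules below are stand-ins, with the head name chosen to match the test's filter string.

```python
# Illustrative sketch (plain PyTorch, not BioNeMo): mirrors the requires_grad check
# added in test_esm2_finetune_regressor for both encoder_frozen settings.
import torch.nn as nn

for encoder_frozen in (True, False):
    # Toy stand-in for an encoder plus a task head.
    model = nn.ModuleDict({"encoder": nn.Linear(8, 8), "regression_head": nn.Linear(8, 1)})

    if encoder_frozen:
        # What the tests expect config_class(encoder_frozen=True) to do to encoder weights.
        for p in model["encoder"].parameters():
            p.requires_grad = False

    encoder_requires_grad = [
        p.requires_grad for name, p in model.named_parameters() if "regression_head" not in name
    ]
    # Logically equivalent to the updated test assertion: frozen <=> no encoder param requires grad.
    assert (not all(encoder_requires_grad)) == encoder_frozen
```

Because the new `--encoder-frozen` flag is a `store_true` argument with `default=False`, existing fine-tuning invocations keep a trainable encoder unless the flag is explicitly passed, as in the updated notebook commands above.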