Allow finetuning ESM2 with [un]frozen encoder (#620)
### Description

Setting `overlap_grad_reduce=True` causes a communication issue in the gradient
synchronization step when the encoder parameters are **not** frozen.

```
AssertionError: Communication call has not been issued for this bucket (79/84 params have grad available)
```

This PR changes the default for `overlap_grad_reduce` to `False` and exposes
`config.encoder_frozen` so that the encoder parameters can optionally be [un]frozen.
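
To make the effect of freezing concrete, here is a minimal, self-contained sketch in plain PyTorch; `encoder` and `head` are stand-in modules for illustration, not the actual ESM2/BioNeMo classes. It shows what `config.encoder_frozen` toggles: the encoder's parameters stop receiving gradients while the task head keeps training.

```python
import torch
from torch import nn

# Hedged illustration only: these modules are stand-ins, not the ESM2 encoder or task head.
encoder = nn.Linear(8, 8)
head = nn.Linear(8, 1)

encoder_frozen = True  # what config.encoder_frozen controls in this PR
for p in encoder.parameters():
    p.requires_grad = not encoder_frozen

loss = head(encoder(torch.randn(4, 8))).sum()
loss.backward()

assert all(p.grad is None for p in encoder.parameters())    # frozen encoder: no gradients
assert all(p.grad is not None for p in head.parameters())    # task head still trains
```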

### Type of changes
<!-- Mark the relevant option with an [x] -->

- [x]  Bug fix (non-breaking change which fixes an issue)
- [x]  New feature (non-breaking change which adds functionality)
- [ ]  Refactor
- [ ]  Documentation update
- [ ]  Other (please describe):

### CI Pipeline Configuration
Configure CI behavior by applying the relevant labels:

- [SKIP_CI](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#skip_ci) - Skip all continuous integration tests
- [INCLUDE_NOTEBOOKS_TESTS](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#include_notebooks_tests) - Execute notebook validation tests in pytest

> [!NOTE]
> By default, the notebook validation tests are skipped unless explicitly enabled.

### Usage
<!--- How does a user interact with the changed code -->
```python
TODO: Add code snippet
```
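
Since the snippet above was left as a TODO, here is a hedged sketch of how the two arguments touched by this PR behave. It reproduces them with a standalone `argparse` parser; the real parser lives in the ESM2 fine-tuning script, carries many more options, and its module path is not shown here.

```python
import argparse

# Standalone reproduction of the two arguments this PR touches (illustrative only,
# not the actual bionemo parser module).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--encoder-frozen",
    action="store_true",
    default=False,
    help="Freeze the encoder parameters",
)
parser.add_argument(
    "--overlap-grad-reduce",  # now opt-in; replaces the old --no-overlap-grad-reduce
    action="store_true",
    default=False,
)

args = parser.parse_args(["--encoder-frozen"])
assert args.encoder_frozen is True        # encoder parameters will be frozen
assert args.overlap_grad_reduce is False  # gradient-reduce overlap stays off by default
```

With `--encoder-frozen` omitted the encoder stays trainable, and `--overlap-grad-reduce` now has to be passed explicitly to re-enable overlapped gradient reduction.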

### Pre-submit Checklist
<!--- Ensure all items are completed before submitting -->

 - [x] I have tested these changes locally
 - [x] I have updated the documentation accordingly
 - [x] I have added/updated tests as needed
 - [x] All existing tests pass successfully

---------

Signed-off-by: Farhad Ramezanghorbani <[email protected]>
farhadrgh authored Jan 22, 2025
1 parent e553389 commit a2fd916
Showing 3 changed files with 37 additions and 7 deletions.
2 changes: 2 additions & 0 deletions docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb
@@ -528,6 +528,7 @@
" --num-gpus 1 \\\n",
" --val-check-interval 10 \\\n",
" --log-every-n-steps 10 \\\n",
" --encoder-frozen \\\n",
" --lr 5e-3 \\\n",
" --lr-multiplier 1e2 \\\n",
" --scale-lr-layer \"regression_head\" \\\n",
@@ -697,6 +698,7 @@
" --num-gpus 1 \\\n",
" --val-check-interval 10 \\\n",
" --log-every-n-steps 10 \\\n",
" --encoder-frozen \\\n",
" --lr 5e-3 \\\n",
" --lr-multiplier 1e2 \\\n",
" --scale-lr-layer \"classification_head\" \\\n",
@@ -76,6 +76,7 @@ def train_model(
experiment_name: str,
resume_if_exists: bool,
precision: PrecisionTypes,
encoder_frozen: bool = False,
scale_lr_layer: Optional[str] = None,
lr_multiplier: float = 1.0,
wandb_entity: Optional[str] = None,
@@ -100,7 +101,7 @@
dataset_class: Type[InMemoryProteinDataset] = InMemorySingleValueDataset,
config_class: Type[BioBertConfig] = ESM2FineTuneSeqConfig,
metric_tracker: Callback | None = None,
overlap_grad_reduce: bool = True,
overlap_grad_reduce: bool = False, # Default to False to avoid communication issue in gradient synchronization step
overlap_param_gather: bool = True,
average_in_collective: bool = True,
grad_reduce_in_fp32: bool = False,
@@ -127,6 +128,7 @@ def train_model(
result_dir that stores the logs and checkpoints.
resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]
precision (PrecisionTypes): Precision type for training (e.g., float16, float32)
encoder_frozen (bool): Freeze the encoder parameters. Default is False.
scale_lr_layer (Optional[str]): layer names for which the lr is scaled by lr_multiplier
lr_multiplier (float): lr multiplier for parameters in scale_lr_layer
wandb_entity (Optional[str]): The team posting this run (default: your username or your default team)
@@ -258,6 +260,7 @@
)
# Configure the model
config = config_class(
encoder_frozen=encoder_frozen,
params_dtype=get_autocast_dtype(precision),
pipeline_dtype=get_autocast_dtype(precision),
autocast_dtype=get_autocast_dtype(precision), # setting this speeds things up a lot
@@ -351,6 +354,7 @@ def finetune_esm2_entrypoint():
tensor_model_parallel_size=args.tensor_model_parallel_size,
accumulate_grad_batches=args.accumulate_grad_batches,
precision=args.precision,
encoder_frozen=args.encoder_frozen,
scale_lr_layer=args.scale_lr_layer,
lr_multiplier=args.lr_multiplier,
experiment_name=args.experiment_name,
@@ -365,7 +369,7 @@ def finetune_esm2_entrypoint():
nsys_ranks=args.nsys_ranks,
dataset_class=args.dataset_class,
config_class=args.config_class,
overlap_grad_reduce=not args.no_overlap_grad_reduce,
overlap_grad_reduce=args.overlap_grad_reduce,
overlap_param_gather=not args.no_overlap_param_gather,
average_in_collective=not args.no_average_in_collective,
grad_reduce_in_fp32=args.grad_reduce_in_fp32,
@@ -398,6 +402,12 @@ def get_parser():
default="bf16-mixed",
help="Precision type to use for training.",
)
parser.add_argument(
"--encoder-frozen",
action="store_true",
default=False,
help="Freeze the encoder parameters",
)
parser.add_argument(
"--lr",
type=float,
@@ -596,7 +606,7 @@ def get_parser():
)
# DDP config
parser.add_argument(
"--no-overlap-grad-reduce",
"--overlap-grad-reduce",
action="store_true",
default=False,
)
@@ -42,9 +42,11 @@ def data_to_csv(data, tmp_path):


@pytest.mark.needs_gpu
@pytest.mark.parametrize("encoder_frozen", [True, False])
def test_esm2_finetune_token_classifier(
tmp_path,
dummy_data_per_token_classification_ft,
encoder_frozen,
n_steps_train: int = 50,
seed: int = 42,
):
@@ -71,6 +73,7 @@ def test_esm2_finetune_token_classifier(
accumulate_grad_batches=1,
resume_if_exists=False,
precision="bf16-mixed",
encoder_frozen=encoder_frozen,
dataset_class=InMemoryPerTokenValueDataset,
config_class=ESM2FineTuneTokenConfig,
metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]),
@@ -85,13 +88,17 @@ def test_esm2_finetune_token_classifier(
encoder_requires_grad = [
p.requires_grad for name, p in trainer.model.named_parameters() if "classification_head" not in name
]
assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning"
assert (
not all(encoder_requires_grad) == encoder_frozen
), f"Conflict in param requires_grad when encoder_frozen={encoder_frozen}"


@pytest.mark.needs_gpu
@pytest.mark.parametrize("encoder_frozen", [True, False])
def test_esm2_finetune_regressor(
tmp_path,
dummy_data_single_value_regression_ft,
encoder_frozen,
n_steps_train: int = 50,
seed: int = 42,
):
@@ -118,6 +125,7 @@ def test_esm2_finetune_regressor(
accumulate_grad_batches=1,
resume_if_exists=False,
precision="bf16-mixed",
encoder_frozen=encoder_frozen,
dataset_class=InMemorySingleValueDataset,
config_class=ESM2FineTuneSeqConfig,
metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]),
@@ -132,7 +140,9 @@ def test_esm2_finetune_regressor(
encoder_requires_grad = [
p.requires_grad for name, p in trainer.model.named_parameters() if "regression_head" not in name
]
assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning"
assert (
not all(encoder_requires_grad) == encoder_frozen
), f"Conflict in param requires_grad when encoder_frozen={encoder_frozen}"


@pytest.fixture
@@ -258,14 +268,19 @@ def test_get_parser():
"--nsys-ranks",
"0",
"1",
"--no-overlap-grad-reduce",
"--overlap-grad-reduce",
"--no-overlap-param-gather",
"--no-average-in-collective",
"--grad-reduce-in-fp32",
"--dataset-class",
"InMemoryPerTokenValueDataset",
"--config-class",
"ESM2FineTuneTokenConfig",
"--encoder-frozen",
"--lr-multiplier",
"1e2",
"--scale-lr-layer",
"dummy_layer",
]
)

@@ -307,9 +322,12 @@ def test_get_parser():
assert args.nsys_start_step == 10
assert args.nsys_end_step == 50
assert args.nsys_ranks == [0, 1]
assert args.no_overlap_grad_reduce is True
assert args.overlap_grad_reduce is True
assert args.no_overlap_param_gather is True
assert args.no_average_in_collective is True
assert args.grad_reduce_in_fp32 is True
assert args.dataset_class == InMemoryPerTokenValueDataset
assert args.config_class == ESM2FineTuneTokenConfig
assert args.encoder_frozen is True
assert args.lr_multiplier == 100
assert args.scale_lr_layer == "dummy_layer"
