From 01fbc8cda6ff9694f6e93689bb3c314dbf267629 Mon Sep 17 00:00:00 2001 From: sichu Date: Wed, 1 Jan 2025 14:04:40 +0000 Subject: [PATCH] add weight decay Signed-off-by: sichu --- .../src/bionemo/esm2/scripts/train_esm2.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py index 847da1dd0b..86989c6baa 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py @@ -96,6 +96,7 @@ def main( overlap_param_gather: bool = False, # TODO waiting for a NeMo fix average_in_collective: bool = True, grad_reduce_in_fp32: bool = False, + weight_decay: float = 0.01, ) -> None: """Train an ESM2 model on UR data. @@ -155,6 +156,7 @@ def main( overlap_param_gather (bool): overlap parameter gather average_in_collective (bool): average in collective grad_reduce_in_fp32 (bool): gradient reduction in fp32 + weight_decay (float): weight decay of the model """ # Create the result directory if it does not exist. result_dir.mkdir(parents=True, exist_ok=True) @@ -283,7 +285,7 @@ def main( lr=lr, optimizer="adam", use_distributed_optimizer=True, - weight_decay=0.01, + weight_decay=weight_decay, adam_beta1=0.9, adam_beta2=0.98, ), @@ -387,6 +389,7 @@ def train_esm2_entrypoint(): overlap_param_gather=args.overlap_param_gather, average_in_collective=not args.no_average_in_collective, grad_reduce_in_fp32=args.grad_reduce_in_fp32, + weight_decay=args.weight_decay, ) @@ -694,6 +697,13 @@ def get_parser(): default=4 * 1280, help="FFN hidden size of the model. Default is 4 * 1280.", ) + parser.add_argument( + "--weight-decay", + type=float, + required=False, + default=0.01, + help="Weight decay of the model. Default is 0.01.", + ) # DDP config parser.add_argument( "--no-overlap-grad-reduce",