diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py
index 87baedf35c..3d80028aa5 100644
--- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py
+++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py
@@ -93,7 +93,7 @@ def main(
     gc_interval: int = 0,
     aligned_megatron_ddp: bool = False,
     recompilation_check: bool = False,
-    bias_fusions: bool = False,
+    skip_bias_fusions: bool = False,
     # TODO add datamodule class, and ability to change data step to get full support for pretraining workflows
 ) -> None:
     """Train a Geneformer model on single cell data.
@@ -148,10 +148,11 @@ def main(
             good for clusters. This will likely slow down single node runs though.
         recompilation_check (bool): enable a recompilation check (only do on a small run) to verify that fused gpu
             kernels are not being regularly recompiled, which is very expensive, with a particular model/settings.
-        bias_fusions (bool): enable two bias fusions (dropout and activation) which improve performance but should be
-            evaluated for impacting training stability. At the very least they trigger recompilations.
+        skip_bias_fusions (bool): Disable the two bias fusions (dropout and activation) which improve performance but
+            cause recompilations. In testing they still seem to result in higher performance despite the recompilations.
     """
     # Create the result directory if it does not exist.
+    bias_fusions = not skip_bias_fusions
     if wandb_tags is None:
         wandb_tags = []
     result_dir.mkdir(parents=True, exist_ok=True)
@@ -624,10 +625,10 @@ def config_class_type(desc: str) -> Type[BioBertConfig]:
         "triggering regular recompilations which can be very expensive for fused gpu kernels.",
    )
     parser.add_argument(
-        "--bias-fusions",
+        "--skip-bias-fusions",
         action="store_true",
         default=False,
-        help="Activate bias fusions which seem to reduce precision but are slightly faster.",
+        help="Deactivate bias fusions which seem to reduce precision but are slightly faster.",
     )
     return parser
@@ -679,7 +680,7 @@ def entrypoint():
         gc_interval=args.gc_interval,
         aligned_megatron_ddp=args.aligned_megatron_ddp,
         recompilation_check=args.recompilation_check,
-        bias_fusions=args.bias_fusions,
+        skip_bias_fusions=args.skip_bias_fusions,
     )
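
For context, the change follows the common argparse pattern of exposing a negative `--skip-*` store_true flag on the CLI and immediately inverting it into the positive boolean the rest of the code expects. The sketch below is illustrative only, not the actual `train_geneformer.py`; the flag and parameter names mirror the diff, while `get_parser` and the `print` stand-in for the training logic are hypothetical.

```python
# Minimal sketch of the skip-flag inversion pattern from the diff, assuming only
# the flag shown in the patch; everything else here is illustrative scaffolding.
import argparse


def main(skip_bias_fusions: bool = False) -> None:
    # Recover the positive setting that downstream configuration code expects.
    bias_fusions = not skip_bias_fusions
    print(f"bias fusions enabled: {bias_fusions}")


def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--skip-bias-fusions",
        action="store_true",
        default=False,
        help="Deactivate bias fusions which seem to reduce precision but are slightly faster.",
    )
    return parser


if __name__ == "__main__":
    args = get_parser().parse_args()
    main(skip_bias_fusions=args.skip_bias_fusions)
```

Note the default: with `action="store_true"` and `default=False`, omitting the flag keeps `skip_bias_fusions` False, so `bias_fusions` stays True and the fusions remain enabled unless the user explicitly passes `--skip-bias-fusions`, which matches the diff's stated intent of keeping the faster path on by default.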