Skip to content

Commit

Permalink
Make the faster option for bias fusions the default
Browse files Browse the repository at this point in the history
  • Loading branch information
jstjohn committed Nov 9, 2024
1 parent cc97275 commit 3be5b56
Showing 1 changed file with 7 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def main(
gc_interval: int = 0,
aligned_megatron_ddp: bool = False,
recompilation_check: bool = False,
bias_fusions: bool = False,
skip_bias_fusions: bool = False,
# TODO add datamodule class, and ability to change data step to get full support for pretraining workflows
) -> None:
"""Train a Geneformer model on single cell data.
Expand Down Expand Up @@ -148,10 +148,11 @@ def main(
good for clusters. This will likely slow down single node runs though.
recompilation_check (bool): enable a recompilation check (only do on a small run) to verify that fused gpu
kernels are not being regularly recompiled, which is very expensive, with a particular model/settings.
bias_fusions (bool): enable two bias fusions (dropout and activation) which improve performance but should be
evaluated for impacting training stability. At the very least they trigger recompilations.
skip_bias_fusions (bool): Disable the two bias fusions (dropout and activation) which improve performance but
cause recompilations. In testing they still seem to result in higher performance despite the recompilations.
"""
# Create the result directory if it does not exist.
bias_fusions = not skip_bias_fusions
if wandb_tags is None:
wandb_tags = []
result_dir.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -624,10 +625,10 @@ def config_class_type(desc: str) -> Type[BioBertConfig]:
"triggering regular recompilations which can be very expensive for fused gpu kernels.",
)
parser.add_argument(
"--bias-fusions",
"--skip-bias-fusions",
action="store_true",
default=False,
help="Activate bias fusions which seem to reduce precision but are slightly faster.",
help="Deactivate bias fusions which seem to reduce precision but are slightly faster.",
)

return parser
Expand Down Expand Up @@ -679,7 +680,7 @@ def entrypoint():
gc_interval=args.gc_interval,
aligned_megatron_ddp=args.aligned_megatron_ddp,
recompilation_check=args.recompilation_check,
bias_fusions=args.bias_fusions,
skip_bias_fusions=args.skip_bias_fusions,
)


Expand Down

0 comments on commit 3be5b56

Please sign in to comment.