Skip to content

Commit

Permalink
Untoggle bias fusions
Browse files Browse the repository at this point in the history
  • Loading branch information
jstjohn committed Nov 9, 2024
1 parent 152b38d commit cc97275
Showing 1 changed file with 12 additions and 2 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -93,6 +93,7 @@ def main(
gc_interval: int = 0,
aligned_megatron_ddp: bool = False,
recompilation_check: bool = False,
bias_fusions: bool = False,
# TODO add datamodule class, and ability to change data step to get full support for pretraining workflows
) -> None:
"""Train a Geneformer model on single cell data.
Expand Down Expand Up @@ -147,6 +148,8 @@ def main(
good for clusters. This will likely slow down single node runs though.
recompilation_check (bool): enable a recompilation check (only do on a small run) to verify that fused gpu
kernels are not being regularly recompiled, which is very expensive, with a particular model/settings.
bias_fusions (bool): enable two bias fusions (bias-dropout and bias-activation) which improve performance but whose
impact on training stability should be evaluated. At the very least they trigger recompilations.
"""
# Create the result directory if it does not exist.
if wandb_tags is None:
Expand Down Expand Up @@ -280,8 +283,8 @@ def main(
ffn_hidden_size=512,
num_attention_heads=4,
seq_length=seq_length,
bias_dropout_fusion=True, # TODO fix the recompilation issue, but for now it's faster even with recompilations
bias_activation_fusion=True, # TODO same note as above. Set these to False to see recompilation go away
bias_dropout_fusion=bias_fusions, # TODO fix the recompilation issue, but for now it's faster even with recompilations
bias_activation_fusion=bias_fusions, # TODO same note as above. Set these to False to see recompilation go away
params_dtype=get_autocast_dtype(precision),
pipeline_dtype=get_autocast_dtype(precision),
autocast_dtype=get_autocast_dtype(precision), # setting this speeds things up a lot
Expand Down Expand Up @@ -620,6 +623,12 @@ def config_class_type(desc: str) -> Type[BioBertConfig]:
help="Activate this and make sure a small training loop runs, this tells you that your settings are not "
"triggering regular recompilations which can be very expensive for fused gpu kernels.",
)
parser.add_argument(
"--bias-fusions",
action="store_true",
default=False,
help="Activate bias fusions which seem to reduce precision but are slightly faster.",
)

return parser

Expand Down Expand Up @@ -670,6 +679,7 @@ def entrypoint():
gc_interval=args.gc_interval,
aligned_megatron_ddp=args.aligned_megatron_ddp,
recompilation_check=args.recompilation_check,
bias_fusions=args.bias_fusions,
)


Expand Down

0 comments on commit cc97275

Please sign in to comment.