diff --git a/docs/train-help.txt b/docs/train-help.txt
index d16519c99..31f38f9fe 100755
--- a/docs/train-help.txt
+++ b/docs/train-help.txt
@@ -65,5 +65,6 @@ Options:
   --fp32 BOOL         Disable mixed-precision training
   --nhwc BOOL         Use NHWC memory format with FP16
   --nobench BOOL      Disable cuDNN benchmarking
+  --allow-tf32 BOOL   Allow PyTorch to use TF32 internally
   --workers INT       Override number of DataLoader workers
   --help              Show this message and exit.
diff --git a/train.py b/train.py
index dfb9b06b8..8d81b3f18 100755
--- a/train.py
+++ b/train.py
@@ -61,6 +61,7 @@ def setup_training_loop_kwargs(
     # Performance options (not included in desc).
     fp32       = None, # Disable mixed-precision training: <bool>, default = False
     nhwc       = None, # Use NHWC memory format with FP16: <bool>, default = False
+    allow_tf32 = None, # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False
     nobench    = None, # Disable cuDNN benchmarking: <bool>, default = False
     workers    = None, # Override number of DataLoader workers: <int>, default = 3
 ):
@@ -343,6 +344,12 @@ def setup_training_loop_kwargs(
     if nobench:
         args.cudnn_benchmark = False
 
+    if allow_tf32 is None:
+        allow_tf32 = False
+    assert isinstance(allow_tf32, bool)
+    if allow_tf32:
+        args.allow_tf32 = True
+
     if workers is not None:
         assert isinstance(workers, int)
         if not workers >= 1:
@@ -425,6 +432,7 @@ def convert(self, value, param, ctx):
 @click.option('--fp32', help='Disable mixed-precision training', type=bool, metavar='BOOL')
 @click.option('--nhwc', help='Use NHWC memory format with FP16', type=bool, metavar='BOOL')
 @click.option('--nobench', help='Disable cuDNN benchmarking', type=bool, metavar='BOOL')
+@click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL')
 @click.option('--workers', help='Override number of DataLoader workers', type=int, metavar='INT')
 
 def main(ctx, outdir, dry_run, **config_kwargs):
diff --git a/training/training_loop.py b/training/training_loop.py
index d25bbfe3b..14836ad2e 100755
--- a/training/training_loop.py
+++ b/training/training_loop.py
@@ -115,6 +115,7 @@ def training_loop(
     network_snapshot_ticks  = 50,       # How often to save network snapshots? None = disable.
     resume_pkl              = None,     # Network pickle to resume training from.
     cudnn_benchmark         = True,     # Enable torch.backends.cudnn.benchmark?
+    allow_tf32              = False,    # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
     abort_fn                = None,     # Callback function for determining whether to abort training. Must return consistent results across ranks.
     progress_fn             = None,     # Callback function for updating training progress. Called for all ranks.
 ):
@@ -124,6 +125,8 @@ def training_loop(
     np.random.seed(random_seed * num_gpus + rank)
     torch.manual_seed(random_seed * num_gpus + rank)
     torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed.
+    torch.backends.cuda.matmul.allow_tf32 = allow_tf32  # Allow PyTorch to internally use tf32 for matmul
+    torch.backends.cudnn.allow_tf32 = allow_tf32        # Allow PyTorch to internally use tf32 for convolutions
     conv2d_gradfix.enabled = True                       # Improves training speed.
     grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.
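
Note (illustrative, not part of the patch): the two torch.backends switches set in training_loop() are existing PyTorch settings (available since PyTorch 1.7). On Ampere or newer GPUs they allow float32 matmuls and cuDNN convolutions to run on TF32 tensor cores, which trades a reduced mantissa (about 10 bits) for speed. A minimal standalone sketch of what the flag toggles, assuming a CUDA-capable machine; the helper name set_tf32 is hypothetical and not part of the codebase:

    import torch

    def set_tf32(enabled: bool) -> None:
        # Mirrors what training_loop() does when allow_tf32 is set.
        torch.backends.cuda.matmul.allow_tf32 = enabled  # float32 matmuls may use TF32 tensor cores
        torch.backends.cudnn.allow_tf32 = enabled        # cuDNN convolutions may use TF32

    set_tf32(True)
    a = torch.randn(1024, 1024, device='cuda')
    b = torch.randn(1024, 1024, device='cuda')
    c = a @ b  # on Ampere+ hardware this matmul can run in TF32 (faster, slightly less precise)

With the patch applied, the same behavior would be requested from the CLI with something like `python train.py --allow-tf32=true` alongside the run's usual dataset and output options.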