Add --allow-tf32 perf tuning argument that can be used to enable tf32
Defaults to keeping tf32 disabled, because we haven't fully verified
training results with tf32 enabled.
jannehellsten committed Feb 11, 2021
1 parent d3a616a commit f7e4867
Showing 3 changed files with 12 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/train-help.txt
@@ -65,5 +65,6 @@ Options:
--fp32 BOOL          Disable mixed-precision training
--nhwc BOOL          Use NHWC memory format with FP16
--nobench BOOL       Disable cuDNN benchmarking
--allow-tf32 BOOL    Allow PyTorch to use TF32 internally
--workers INT        Override number of DataLoader workers
--help               Show this message and exit.
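
For reference, the new option slots into an ordinary training command. A hypothetical invocation that opts into TF32 might look like the line below; the --outdir, --data and --gpus values are placeholders and not part of this commit:

python train.py --outdir=~/training-runs --data=~/datasets/mydataset.zip --gpus=2 --allow-tf32=true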
8 changes: 8 additions & 0 deletions train.py
@@ -61,6 +61,7 @@ def setup_training_loop_kwargs(
    # Performance options (not included in desc).
    fp32 = None, # Disable mixed-precision training: <bool>, default = False
    nhwc = None, # Use NHWC memory format with FP16: <bool>, default = False
    allow_tf32 = None, # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False
    nobench = None, # Disable cuDNN benchmarking: <bool>, default = False
    workers = None, # Override number of DataLoader workers: <int>, default = 3
):
@@ -343,6 +344,12 @@ def setup_training_loop_kwargs(
    if nobench:
        args.cudnn_benchmark = False

    if allow_tf32 is None:
        allow_tf32 = False
    assert isinstance(allow_tf32, bool)
    if allow_tf32:
        args.allow_tf32 = True

    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
@@ -425,6 +432,7 @@ def convert(self, value, param, ctx):
@click.option('--fp32', help='Disable mixed-precision training', type=bool, metavar='BOOL')
@click.option('--nhwc', help='Use NHWC memory format with FP16', type=bool, metavar='BOOL')
@click.option('--nobench', help='Disable cuDNN benchmarking', type=bool, metavar='BOOL')
@click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL')
@click.option('--workers', help='Override number of DataLoader workers', type=int, metavar='INT')

def main(ctx, outdir, dry_run, **config_kwargs):
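
Note how the option is only written into the training-loop arguments when it is truthy, so training_loop() falls back to its own allow_tf32=False default whenever the flag is omitted. A minimal, self-contained sketch of that pattern (simplified, not the repository's actual plumbing):

# Sketch of the perf-option pattern used above: the key is only added to the
# args dict when the user opts in, so the callee's default applies otherwise.
def setup_kwargs(allow_tf32=None):
    args = {}
    if allow_tf32 is None:
        allow_tf32 = False
    assert isinstance(allow_tf32, bool)
    if allow_tf32:
        args['allow_tf32'] = True  # only set when explicitly enabled
    return args

def training_loop(allow_tf32=False, **kwargs):
    print('TF32 enabled:', allow_tf32)

training_loop(**setup_kwargs())                  # TF32 enabled: False
training_loop(**setup_kwargs(allow_tf32=True))   # TF32 enabled: True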
3 changes: 3 additions & 0 deletions training/training_loop.py
@@ -115,6 +115,7 @@ def training_loop(
    network_snapshot_ticks = 50, # How often to save network snapshots? None = disable.
    resume_pkl = None, # Network pickle to resume training from.
    cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
    allow_tf32 = False, # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
    abort_fn = None, # Callback function for determining whether to abort training. Must return consistent results across ranks.
    progress_fn = None, # Callback function for updating training progress. Called for all ranks.
):
@@ -124,6 +125,8 @@ def training_loop(
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark # Improves training speed.
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32 # Allow PyTorch to internally use tf32 for matmul
    torch.backends.cudnn.allow_tf32 = allow_tf32 # Allow PyTorch to internally use tf32 for convolutions
    conv2d_gradfix.enabled = True # Improves training speed.
    grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe.

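
For anyone wanting to check the effect of these two switches outside the training loop, here is a standalone sketch (not part of the commit). On Ampere-or-newer GPUs, enabling TF32 lets float32 matmuls and cuDNN convolutions run on tensor cores with a shorter mantissa, which typically shows up as a small numerical difference; on older GPUs the flags have no effect.

import torch

def matmul_with_tf32(allow_tf32):
    # The same two global switches set in training_loop.py above.
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32
    torch.backends.cudnn.allow_tf32 = allow_tf32
    torch.manual_seed(0)
    a = torch.randn(2048, 2048, device='cuda')
    b = torch.randn(2048, 2048, device='cuda')
    return a @ b

if torch.cuda.is_available():
    diff = (matmul_with_tf32(False) - matmul_with_tf32(True)).abs().max().item()
    print(f'Max abs difference between fp32 and tf32 matmul: {diff:.6f}')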
