From e74827f4926cf269610d017827648bbf817bae17 Mon Sep 17 00:00:00 2001 From: Janne Hellsten Date: Fri, 22 Apr 2022 17:35:34 +0300 Subject: [PATCH 01/13] pytorch 1.11 support: don't use conv2d_gradfix on v1.11, port grid_sample_gradfix to the new API thanks @timothybrooks for the fix! for #145 --- torch_utils/ops/conv2d_gradfix.py | 5 +++++ torch_utils/ops/grid_sample_gradfix.py | 8 +++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/torch_utils/ops/conv2d_gradfix.py b/torch_utils/ops/conv2d_gradfix.py index e95e10d0b..bb70f1ea6 100755 --- a/torch_utils/ops/conv2d_gradfix.py +++ b/torch_utils/ops/conv2d_gradfix.py @@ -12,6 +12,7 @@ import warnings import contextlib import torch +from pkg_resources import parse_version # pylint: disable=redefined-builtin # pylint: disable=arguments-differ @@ -21,6 +22,7 @@ enabled = False # Enable the custom op by setting this to true. weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. +_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 @contextlib.contextmanager def no_weight_gradients(): @@ -48,6 +50,9 @@ def _should_use_custom_op(input): assert isinstance(input, torch.Tensor) if (not enabled) or (not torch.backends.cudnn.enabled): return False + if _use_pytorch_1_11_api: + # The work-around code doesn't work on PyTorch 1.11.0 onwards + return False if input.device.type != 'cuda': return False if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): diff --git a/torch_utils/ops/grid_sample_gradfix.py b/torch_utils/ops/grid_sample_gradfix.py index ca6b3413e..9f1b7e9c8 100755 --- a/torch_utils/ops/grid_sample_gradfix.py +++ b/torch_utils/ops/grid_sample_gradfix.py @@ -13,6 +13,7 @@ import warnings import torch +from pkg_resources import parse_version # pylint: disable=redefined-builtin # pylint: disable=arguments-differ @@ -21,6 +22,7 @@ #---------------------------------------------------------------------------- enabled = False # Enable the custom op by setting this to true. 
+_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 #---------------------------------------------------------------------------- @@ -62,7 +64,11 @@ class _GridSample2dBackward(torch.autograd.Function): @staticmethod def forward(ctx, grad_output, input, grid): op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') - grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) + if _use_pytorch_1_11_api: + output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2]) + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask) + else: + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) ctx.save_for_backward(grid) return grad_input, grad_grid From 4397f15bfdc4176900eb8471ec8380e959cb5d8b Mon Sep 17 00:00:00 2001 From: Janne Hellsten Date: Wed, 26 Apr 2023 14:40:16 +0300 Subject: [PATCH 02/13] Fix custom ops bug for pytorch 1.12 and onwards Adapt to newer _jit_get_operation API that changed in https://github.com/pytorch/pytorch/pull/76814 for #188, #193 --- torch_utils/ops/grid_sample_gradfix.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch_utils/ops/grid_sample_gradfix.py b/torch_utils/ops/grid_sample_gradfix.py index 9f1b7e9c8..538ee400a 100755 --- a/torch_utils/ops/grid_sample_gradfix.py +++ b/torch_utils/ops/grid_sample_gradfix.py @@ -23,6 +23,7 @@ enabled = False # Enable the custom op by setting this to true. _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 +_use_pytorch_1_12_api = parse_version(torch.__version__) >= parse_version('1.12.0a') # Allow prerelease builds of 1.12 #---------------------------------------------------------------------------- @@ -64,6 +65,8 @@ class _GridSample2dBackward(torch.autograd.Function): @staticmethod def forward(ctx, grad_output, input, grid): op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') + if _use_pytorch_1_12_api: + op = op[0] if _use_pytorch_1_11_api: output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2]) grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask) From 471c7cba67fe82656c33a86df15dc0beae4ed5c4 Mon Sep 17 00:00:00 2001 From: Wok Date: Wed, 10 Jan 2024 14:27:50 +0100 Subject: [PATCH 03/13] Support newer versions of PyTorch (v1.1X and v2) --- torch_utils/ops/grid_sample_gradfix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_utils/ops/grid_sample_gradfix.py b/torch_utils/ops/grid_sample_gradfix.py index 538ee400a..98b5b97b6 100755 --- a/torch_utils/ops/grid_sample_gradfix.py +++ b/torch_utils/ops/grid_sample_gradfix.py @@ -37,7 +37,7 @@ def grid_sample(input, grid): def _should_use_custom_op(): if not enabled: return False - if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): + if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9', '1.1', '2']): return True warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. 
Falling back to torch.nn.functional.grid_sample().') return False From edc9f16e3e6cdaa7e345c995821d0284d31f1509 Mon Sep 17 00:00:00 2001 From: Wok Date: Mon, 1 Feb 2021 13:25:11 +0100 Subject: [PATCH 04/13] Save output as JPG instead of PNG --- training/training_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/training/training_loop.py b/training/training_loop.py index 14836ad2e..8d137c189 100755 --- a/training/training_loop.py +++ b/training/training_loop.py @@ -220,11 +220,11 @@ def training_loop( if rank == 0: print('Exporting sample images...') grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set) - save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size) + save_image_grid(images, os.path.join(run_dir, 'reals.jpg'), drange=[0,255], grid_size=grid_size) grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu) grid_c = torch.from_numpy(labels).to(device).split(batch_gpu) images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy() - save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size) + save_image_grid(images, os.path.join(run_dir, 'fakes_init.jpg'), drange=[-1,1], grid_size=grid_size) # Initialize logs. if rank == 0: @@ -347,7 +347,7 @@ def training_loop( # Save image snapshot. if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0): images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy() - save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size) + save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.jpg'), drange=[-1,1], grid_size=grid_size) # Save network snapshot. 
snapshot_pkl = None From bd645cffc7e43f2b70a61841ecfa608a2e3c0b61 Mon Sep 17 00:00:00 2001 From: Wok Date: Mon, 1 Feb 2021 13:29:05 +0100 Subject: [PATCH 05/13] Add utility functions --- training/misc.py | 66 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 training/misc.py diff --git a/training/misc.py b/training/misc.py new file mode 100644 index 000000000..d40cc2069 --- /dev/null +++ b/training/misc.py @@ -0,0 +1,66 @@ +import glob +import os +import re + +from pathlib import Path + +def get_parent_dir(run_dir): + out_dir = Path(run_dir).parent + + return out_dir + +def locate_latest_pkl(out_dir): + all_pickle_names = sorted(glob.glob(os.path.join(out_dir, '0*', 'network-*.pkl'))) + + try: + latest_pickle_name = all_pickle_names[-1] + except IndexError: + latest_pickle_name = None + + return latest_pickle_name + +def parse_kimg_from_network_name(network_pickle_name): + + if network_pickle_name is not None: + resume_run_id = os.path.basename(os.path.dirname(network_pickle_name)) + RE_KIMG = re.compile('network-snapshot-(\d+).pkl') + try: + kimg = int(RE_KIMG.match(os.path.basename(network_pickle_name)).group(1)) + except AttributeError: + kimg = 0.0 + else: + kimg = 0.0 + + return float(kimg) + + +def parse_augment_p_from_log(network_pickle_name): + + if network_pickle_name is not None: + network_folder_name = os.path.dirname(network_pickle_name) + log_file_name = network_folder_name + "/log.txt" + + try: + with open(log_file_name, "r") as f: + # Tokenize each line starting with the word 'tick' + lines = [ + l.strip().split() for l in f.readlines() if l.startswith("tick") + ] + except FileNotFoundError: + lines = [] + + # Extract the last token of each line for which the second to last token is 'augment' + values = [ + tokens[-1] + for tokens in lines + if len(tokens) > 1 and tokens[-2] == "augment" + ] + + if len(values)>0: + augment_p = float(values[-1]) + else: + augment_p = 0.0 + else: + augment_p = 0.0 + + return float(augment_p) From 61bc29cc3a2b6aa02d0afe224d2dd85329a9dcea Mon Sep 17 00:00:00 2001 From: Wok Date: Mon, 1 Feb 2021 13:34:43 +0100 Subject: [PATCH 06/13] Resume from the latest pickle --- training/training_loop.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/training/training_loop.py b/training/training_loop.py index 8d137c189..91addf3d8 100755 --- a/training/training_loop.py +++ b/training/training_loop.py @@ -23,6 +23,7 @@ import legacy from metrics import metric_main +from training import misc as tmisc #---------------------------------------------------------------------------- @@ -152,6 +153,10 @@ def training_loop( G_ema = copy.deepcopy(G).eval() # Resume from existing pickle. 
+ if resume_pkl == 'latest': + out_dir = tmisc.get_parent_dir(run_dir) + resume_pkl = tmisc.locate_latest_pkl(out_dir) + if (resume_pkl is not None) and (rank == 0): print(f'Resuming from "{resume_pkl}"') with dnnlib.util.open_url(resume_pkl) as f: From 7eeedaf323938da31ccd907dd49674b006ff0bd0 Mon Sep 17 00:00:00 2001 From: Wok Date: Mon, 1 Feb 2021 13:38:10 +0100 Subject: [PATCH 07/13] Automatically set the resume value of kimg --- training/training_loop.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/training/training_loop.py b/training/training_loop.py index 91addf3d8..04ee1fd7b 100755 --- a/training/training_loop.py +++ b/training/training_loop.py @@ -157,6 +157,10 @@ def training_loop( out_dir = tmisc.get_parent_dir(run_dir) resume_pkl = tmisc.locate_latest_pkl(out_dir) + resume_kimg = tmisc.parse_kimg_from_network_name(resume_pkl) + if resume_kimg > 0: + print(f'Resuming from kimg = {resume_kimg}') + if (resume_pkl is not None) and (rank == 0): print(f'Resuming from "{resume_pkl}"') with dnnlib.util.open_url(resume_pkl) as f: @@ -250,14 +254,14 @@ def training_loop( if rank == 0: print(f'Training for {total_kimg} kimg...') print() - cur_nimg = 0 + cur_nimg = int(resume_kimg * 1000) cur_tick = 0 tick_start_nimg = cur_nimg tick_start_time = time.time() maintenance_time = tick_start_time - start_time batch_idx = 0 if progress_fn is not None: - progress_fn(0, total_kimg) + progress_fn(int(resume_kimg), total_kimg) while True: # Fetch training data. From 0e4e93e80059e733be35aaa44b50743722efdc32 Mon Sep 17 00:00:00 2001 From: Wok Date: Mon, 1 Feb 2021 13:39:24 +0100 Subject: [PATCH 08/13] Allow to manually set the resume value of the augmentation strength --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 8d81b3f18..e20f3ca9f 100755 --- a/train.py +++ b/train.py @@ -244,8 +244,8 @@ def setup_training_loop_kwargs( if p is not None: assert isinstance(p, float) - if aug != 'fixed': - raise UserError('--p can only be specified with --aug=fixed') + if resume != 'latest' and aug != 'fixed': + raise UserError('--p can only be specified with --resume=latest or --aug=fixed') if not 0 <= p <= 1: raise UserError('--p must be between 0 and 1') desc += f'-p{p:g}' From 7e29bb5f9f3e129b620a6a962135df58b1c53e90 Mon Sep 17 00:00:00 2001 From: Wok Date: Mon, 1 Feb 2021 22:02:30 +0100 Subject: [PATCH 09/13] Automatically set the resume value of augment_p --- training/training_loop.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/training/training_loop.py b/training/training_loop.py index 04ee1fd7b..5c2a17892 100755 --- a/training/training_loop.py +++ b/training/training_loop.py @@ -161,6 +161,12 @@ def training_loop( if resume_kimg > 0: print(f'Resuming from kimg = {resume_kimg}') + if ada_target is not None and augment_p == 0: + # Overwrite augment_p only if the augmentation probability is not fixed by the user + augment_p = tmisc.parse_augment_p_from_log(resume_pkl) + if augment_p > 0: + print(f'Resuming with augment_p = {augment_p}') + if (resume_pkl is not None) and (rank == 0): print(f'Resuming from "{resume_pkl}"') with dnnlib.util.open_url(resume_pkl) as f: From 2b82b6a8dfcd05b67823bf70fab7166888c9edae Mon Sep 17 00:00:00 2001 From: Wok Date: Mon, 1 Feb 2021 13:44:49 +0100 Subject: [PATCH 10/13] Add cfg (auto_norp): auto cfg without EMA rampup --- train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index e20f3ca9f..9982b1c4b 100755 --- a/train.py +++ 
b/train.py @@ -153,6 +153,7 @@ def setup_training_loop_kwargs( cfg_specs = { 'auto': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2), # Populated dynamically based on resolution and GPU count. + 'auto_norp': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=None, map=2), 'stylegan2': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # Uses mixed-precision, unlike the original StyleGAN2. 'paper256': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8), 'paper512': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8), @@ -162,7 +163,7 @@ def setup_training_loop_kwargs( assert cfg in cfg_specs spec = dnnlib.EasyDict(cfg_specs[cfg]) - if cfg == 'auto': + if cfg.startswith('auto'): desc += f'{gpus:d}' spec.ref_gpus = gpus res = args.training_set_kwargs.resolution @@ -413,7 +414,7 @@ def convert(self, value, param, ctx): @click.option('--mirror', help='Enable dataset x-flips [default: false]', type=bool, metavar='BOOL') # Base config. -@click.option('--cfg', help='Base config [default: auto]', type=click.Choice(['auto', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'])) +@click.option('--cfg', help='Base config [default: auto]', type=click.Choice(['auto', 'auto_norp', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'])) @click.option('--gamma', help='Override R1 gamma', type=float) @click.option('--kimg', help='Override training duration', type=int, metavar='INT') @click.option('--batch', help='Override batch size', type=int, metavar='INT') From a51de587cffac5d122f9cc23fd2ced4c4ef74c0d Mon Sep 17 00:00:00 2001 From: Wok Date: Mon, 1 Feb 2021 13:48:56 +0100 Subject: [PATCH 11/13] Allow to override mapping net depth with --cfg_map --- train.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/train.py b/train.py index 9982b1c4b..754a06003 100755 --- a/train.py +++ b/train.py @@ -47,6 +47,7 @@ def setup_training_loop_kwargs( gamma = None, # Override R1 gamma: kimg = None, # Override training duration: batch = None, # Override batch size: + cfg_map = None, # Override config map: , default = depends on cfg # Discriminator augmentation. aug = None, # Augmentation mode: 'ada' (default), 'noaug', 'fixed' @@ -220,6 +221,12 @@ def setup_training_loop_kwargs( args.batch_size = batch args.batch_gpu = batch // gpus + if cfg_map is not None: + assert isinstance(cfg_map, int) + if not cfg_map >= 1: + raise UserError('--cfg_map must be at least 1') + args.G_kwargs.mapping_kwargs.num_layers = cfg_map + # --------------------------------------------------- # Discriminator augmentation: aug, p, target, augpipe # --------------------------------------------------- @@ -418,6 +425,7 @@ def convert(self, value, param, ctx): @click.option('--gamma', help='Override R1 gamma', type=float) @click.option('--kimg', help='Override training duration', type=int, metavar='INT') @click.option('--batch', help='Override batch size', type=int, metavar='INT') +@click.option('--cfg_map', help='Override config map', type=int, metavar='INT') # Discriminator augmentation. 
@click.option('--aug', help='Augmentation mode [default: ada]', type=click.Choice(['noaug', 'ada', 'fixed'])) From f83452fb6ba4ba49c538db1f37aa5a540c189241 Mon Sep 17 00:00:00 2001 From: Wok Date: Mon, 1 Feb 2021 13:52:50 +0100 Subject: [PATCH 12/13] Allow to enforce CIFAR-specific architecture tuning with --cifar_tune --- train.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index 754a06003..1b8fc9741 100755 --- a/train.py +++ b/train.py @@ -44,6 +44,7 @@ def setup_training_loop_kwargs( # Base config. cfg = None, # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar' + cifar_tune = None, # Enforce CIFAR-specific architecture tuning: , default = False gamma = None, # Override R1 gamma: kimg = None, # Override training duration: batch = None, # Override batch size: @@ -194,7 +195,14 @@ def setup_training_loop_kwargs( args.ema_kimg = spec.ema args.ema_rampup = spec.ramp - if cfg == 'cifar': + if cifar_tune is None: + cifar_tune = False + else: + assert isinstance(cifar_tune, bool) + if cifar_tune: + desc += '-tuning' + + if cifar_tune or cfg == 'cifar': args.loss_kwargs.pl_weight = 0 # disable path length regularization args.loss_kwargs.style_mixing_prob = 0 # disable style mixing args.D_kwargs.architecture = 'orig' # disable residual skip connections @@ -422,6 +430,7 @@ def convert(self, value, param, ctx): # Base config. @click.option('--cfg', help='Base config [default: auto]', type=click.Choice(['auto', 'auto_norp', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'])) +@click.option('--cifar_tune', help='Enforce CIFAR-specific architecture tuning (default: false)', type=bool, metavar='BOOL') @click.option('--gamma', help='Override R1 gamma', type=float) @click.option('--kimg', help='Override training duration', type=int, metavar='INT') @click.option('--batch', help='Override batch size', type=int, metavar='INT') @click.option('--cfg_map', help='Override config map', type=int, metavar='INT') From 2c507a6682a0e9eded649ea259f16ef85fabb0fb Mon Sep 17 00:00:00 2001 From: Wok Date: Mon, 1 Feb 2021 18:34:30 +0100 Subject: [PATCH 13/13] README: add changelog --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index 880974e33..b29e59d48 100755 --- a/README.md +++ b/README.md @@ -1,3 +1,14 @@ +## Changelog + +- save output images as JPG, +- automatically resume from the latest `.pkl` file with the command-line argument `--resume=latest`, +- automatically set the resume value of `kimg`, +- automatically set the resume value of the augmentation strength, +- allow **manually** setting the resume value of the augmentation strength, +- add config `auto_norp` to replicate the `auto` config without EMA rampup, +- allow overriding the mapping net depth with the command-line argument `--cfg_map`, +- allow enforcing CIFAR-specific architecture tuning with the command-line argument `--cifar_tune`. + ## StyleGAN2-ADA — Official PyTorch implementation ![Teaser image](./docs/stylegan2-ada-teaser-1024x252.png)
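
Note on the resume flow introduced by patches 05-09: the sketch below condenses how the helpers added in training/misc.py interact with the new logic at the top of training_loop() when --resume=latest is passed. It is illustrative only; the wrapper function resolve_resume_state is invented here for clarity and does not exist in the patches, where the equivalent code is inlined in training/training_loop.py.

# Illustrative sketch, not part of the patch series: resolve_resume_state() is an invented wrapper;
# the equivalent code is inlined in training/training_loop.py (patches 06, 07 and 09).
from training import misc as tmisc

def resolve_resume_state(run_dir, resume_pkl, ada_target, augment_p):
    # '--resume=latest' resolves to the newest '0*/network-*.pkl' under the parent of the results dir.
    if resume_pkl == 'latest':
        out_dir = tmisc.get_parent_dir(run_dir)
        resume_pkl = tmisc.locate_latest_pkl(out_dir)   # None if no snapshot exists yet

    # Recover the kimg counter from 'network-snapshot-NNNNNN.pkl' so the schedule continues where it stopped.
    resume_kimg = tmisc.parse_kimg_from_network_name(resume_pkl)

    # Recover the ADA strength from log.txt, but only when ADA is active and the user did not pin --p explicitly.
    if ada_target is not None and augment_p == 0:
        augment_p = tmisc.parse_augment_p_from_log(resume_pkl)

    cur_nimg = int(resume_kimg * 1000)                  # the training loop counts individual images, not kimg
    return resume_pkl, resume_kimg, augment_p, cur_nimg

Because the override in patch 09 only fires while augment_p is still 0, a non-zero strength passed with --p (which patch 08 permits alongside --resume=latest) takes precedence over the value recovered from log.txt.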