diff --git a/README.md b/README.md
index 880974e33..b29e59d48 100755
--- a/README.md
+++ b/README.md
@@ -1,3 +1,14 @@
+## Changelog
+
+- save output images as JPG,
+- automatically resume from the latest `.pkl` file with the command-line argument `--resume=latest`,
+- automatically set the resume value of `kimg`,
+- automatically set the resume value of the augmentation strength,
+- allow **manually** setting the resume value of the augmentation strength,
+- add the `auto_norp` config to replicate the `auto` config without EMA rampup,
+- allow overriding the mapping net depth with the command-line argument `--cfg_map`,
+- allow enforcing CIFAR-specific architecture tuning with the command-line argument `--cifar_tune`.
+
 ## StyleGAN2-ADA — Official PyTorch implementation
 
 ![Teaser image](./docs/stylegan2-ada-teaser-1024x252.png)
diff --git a/torch_utils/ops/conv2d_gradfix.py b/torch_utils/ops/conv2d_gradfix.py
index e95e10d0b..bb70f1ea6 100755
--- a/torch_utils/ops/conv2d_gradfix.py
+++ b/torch_utils/ops/conv2d_gradfix.py
@@ -12,6 +12,7 @@
 import warnings
 import contextlib
 import torch
+from pkg_resources import parse_version
 
 # pylint: disable=redefined-builtin
 # pylint: disable=arguments-differ
@@ -21,6 +22,7 @@
 
 enabled = False # Enable the custom op by setting this to true.
 weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
+_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
 
 @contextlib.contextmanager
 def no_weight_gradients():
@@ -48,6 +50,9 @@ def _should_use_custom_op(input):
     assert isinstance(input, torch.Tensor)
     if (not enabled) or (not torch.backends.cudnn.enabled):
         return False
+    if _use_pytorch_1_11_api:
+        # The work-around code doesn't work on PyTorch 1.11.0 onwards
+        return False
     if input.device.type != 'cuda':
         return False
     if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
diff --git a/torch_utils/ops/grid_sample_gradfix.py b/torch_utils/ops/grid_sample_gradfix.py
index ca6b3413e..98b5b97b6 100755
--- a/torch_utils/ops/grid_sample_gradfix.py
+++ b/torch_utils/ops/grid_sample_gradfix.py
@@ -13,6 +13,7 @@
 
 import warnings
 import torch
+from pkg_resources import parse_version
 
 # pylint: disable=redefined-builtin
 # pylint: disable=arguments-differ
@@ -21,6 +22,8 @@
 
 #----------------------------------------------------------------------------
 
 enabled = False # Enable the custom op by setting this to true.
+_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
+_use_pytorch_1_12_api = parse_version(torch.__version__) >= parse_version('1.12.0a') # Allow prerelease builds of 1.12
 
 #----------------------------------------------------------------------------
@@ -34,7 +37,7 @@ def grid_sample(input, grid):
 def _should_use_custom_op():
     if not enabled:
         return False
-    if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
+    if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9', '1.1', '2']):
         return True
     warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
     return False
@@ -62,7 +65,13 @@ class _GridSample2dBackward(torch.autograd.Function):
     @staticmethod
     def forward(ctx, grad_output, input, grid):
         op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
-        grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
+        if _use_pytorch_1_12_api:
+            op = op[0]
+        if _use_pytorch_1_11_api:
+            output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
+            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
+        else:
+            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
         ctx.save_for_backward(grid)
         return grad_input, grad_grid
 
diff --git a/train.py b/train.py
index 8d81b3f18..1b8fc9741 100755
--- a/train.py
+++ b/train.py
@@ -44,9 +44,11 @@ def setup_training_loop_kwargs(
 
     # Base config.
     cfg = None, # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
+    cifar_tune = None, # Enforce CIFAR-specific architecture tuning: <bool>, default = False
     gamma = None, # Override R1 gamma: <float>
     kimg = None, # Override training duration: <int>
     batch = None, # Override batch size: <int>
+    cfg_map = None, # Override config map: <int>, default = depends on cfg
 
     # Discriminator augmentation.
     aug = None, # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
@@ -153,6 +155,7 @@
 
     cfg_specs = {
         'auto': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2), # Populated dynamically based on resolution and GPU count.
+        'auto_norp': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=None, map=2),
         'stylegan2': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # Uses mixed-precision, unlike the original StyleGAN2.
         'paper256': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8),
         'paper512': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8),
@@ -162,7 +165,7 @@
 
     assert cfg in cfg_specs
     spec = dnnlib.EasyDict(cfg_specs[cfg])
-    if cfg == 'auto':
+    if cfg.startswith('auto'):
         desc += f'{gpus:d}'
         spec.ref_gpus = gpus
         res = args.training_set_kwargs.resolution
@@ -192,7 +195,14 @@
     args.ema_kimg = spec.ema
     args.ema_rampup = spec.ramp
 
-    if cfg == 'cifar':
+    if cifar_tune is None:
+        cifar_tune = False
+    else:
+        assert isinstance(cifar_tune, bool)
+    if cifar_tune:
+        desc += '-tuning'
+
+    if cifar_tune or cfg == 'cifar':
         args.loss_kwargs.pl_weight = 0 # disable path length regularization
         args.loss_kwargs.style_mixing_prob = 0 # disable style mixing
         args.D_kwargs.architecture = 'orig' # disable residual skip connections
@@ -219,6 +229,12 @@
         args.batch_size = batch
         args.batch_gpu = batch // gpus
 
+    if cfg_map is not None:
+        assert isinstance(cfg_map, int)
+        if not cfg_map >= 1:
+            raise UserError('--cfg_map must be at least 1')
+        args.G_kwargs.mapping_kwargs.num_layers = cfg_map
+
     # ---------------------------------------------------
     # Discriminator augmentation: aug, p, target, augpipe
     # ---------------------------------------------------
@@ -244,8 +260,8 @@
 
     if p is not None:
         assert isinstance(p, float)
-        if aug != 'fixed':
-            raise UserError('--p can only be specified with --aug=fixed')
+        if resume != 'latest' and aug != 'fixed':
+            raise UserError('--p can only be specified with --resume=latest or --aug=fixed')
         if not 0 <= p <= 1:
             raise UserError('--p must be between 0 and 1')
         desc += f'-p{p:g}'
@@ -413,10 +429,12 @@ def convert(self, value, param, ctx):
 @click.option('--mirror', help='Enable dataset x-flips [default: false]', type=bool, metavar='BOOL')
 
 # Base config.
-@click.option('--cfg', help='Base config [default: auto]', type=click.Choice(['auto', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar']))
+@click.option('--cfg', help='Base config [default: auto]', type=click.Choice(['auto', 'auto_norp', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar']))
+@click.option('--cifar_tune', help='Enforce CIFAR-specific architecture tuning (default: false)', type=bool, metavar='BOOL')
 @click.option('--gamma', help='Override R1 gamma', type=float)
 @click.option('--kimg', help='Override training duration', type=int, metavar='INT')
 @click.option('--batch', help='Override batch size', type=int, metavar='INT')
+@click.option('--cfg_map', help='Override config map', type=int, metavar='INT')
 
 # Discriminator augmentation.
 @click.option('--aug', help='Augmentation mode [default: ada]', type=click.Choice(['noaug', 'ada', 'fixed']))
diff --git a/training/misc.py b/training/misc.py
new file mode 100644
index 000000000..d40cc2069
--- /dev/null
+++ b/training/misc.py
@@ -0,0 +1,66 @@
+import glob
+import os
+import re
+
+from pathlib import Path
+
+def get_parent_dir(run_dir):
+    out_dir = Path(run_dir).parent
+
+    return out_dir
+
+def locate_latest_pkl(out_dir):
+    all_pickle_names = sorted(glob.glob(os.path.join(out_dir, '0*', 'network-*.pkl')))
+
+    try:
+        latest_pickle_name = all_pickle_names[-1]
+    except IndexError:
+        latest_pickle_name = None
+
+    return latest_pickle_name
+
+def parse_kimg_from_network_name(network_pickle_name):
+
+    if network_pickle_name is not None:
+        resume_run_id = os.path.basename(os.path.dirname(network_pickle_name))
+        RE_KIMG = re.compile(r'network-snapshot-(\d+)\.pkl')
+        try:
+            kimg = int(RE_KIMG.match(os.path.basename(network_pickle_name)).group(1))
+        except AttributeError:
+            kimg = 0.0
+    else:
+        kimg = 0.0
+
+    return float(kimg)
+
+
+def parse_augment_p_from_log(network_pickle_name):
+
+    if network_pickle_name is not None:
+        network_folder_name = os.path.dirname(network_pickle_name)
+        log_file_name = network_folder_name + "/log.txt"
+
+        try:
+            with open(log_file_name, "r") as f:
+                # Tokenize each line starting with the word 'tick'
+                lines = [
+                    l.strip().split() for l in f.readlines() if l.startswith("tick")
+                ]
+        except FileNotFoundError:
+            lines = []
+
+        # Extract the last token of each line for which the second to last token is 'augment'
+        values = [
+            tokens[-1]
+            for tokens in lines
+            if len(tokens) > 1 and tokens[-2] == "augment"
+        ]
+
+        if len(values) > 0:
+            augment_p = float(values[-1])
+        else:
+            augment_p = 0.0
+    else:
+        augment_p = 0.0
+
+    return float(augment_p)
diff --git a/training/training_loop.py b/training/training_loop.py
index 14836ad2e..5c2a17892 100755
--- a/training/training_loop.py
+++ b/training/training_loop.py
@@ -23,6 +23,7 @@
 
 import legacy
 from metrics import metric_main
+from training import misc as tmisc
 
 #----------------------------------------------------------------------------
 
@@ -152,6 +153,20 @@
     G_ema = copy.deepcopy(G).eval()
 
     # Resume from existing pickle.
+    if resume_pkl == 'latest':
+        out_dir = tmisc.get_parent_dir(run_dir)
+        resume_pkl = tmisc.locate_latest_pkl(out_dir)
+
+    resume_kimg = tmisc.parse_kimg_from_network_name(resume_pkl)
+    if resume_kimg > 0:
+        print(f'Resuming from kimg = {resume_kimg}')
+
+    if ada_target is not None and augment_p == 0:
+        # Overwrite augment_p only if the augmentation probability is not fixed by the user
+        augment_p = tmisc.parse_augment_p_from_log(resume_pkl)
+        if augment_p > 0:
+            print(f'Resuming with augment_p = {augment_p}')
+
     if (resume_pkl is not None) and (rank == 0):
         print(f'Resuming from "{resume_pkl}"')
         with dnnlib.util.open_url(resume_pkl) as f:
@@ -220,11 +235,11 @@
     if rank == 0:
         print('Exporting sample images...')
         grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set)
-        save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size)
+        save_image_grid(images, os.path.join(run_dir, 'reals.jpg'), drange=[0,255], grid_size=grid_size)
         grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
         grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
         images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
-        save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
+        save_image_grid(images, os.path.join(run_dir, 'fakes_init.jpg'), drange=[-1,1], grid_size=grid_size)
 
     # Initialize logs.
     if rank == 0:
@@ -245,14 +260,14 @@
     if rank == 0:
         print(f'Training for {total_kimg} kimg...')
         print()
-    cur_nimg = 0
+    cur_nimg = int(resume_kimg * 1000)
     cur_tick = 0
     tick_start_nimg = cur_nimg
     tick_start_time = time.time()
     maintenance_time = tick_start_time - start_time
     batch_idx = 0
     if progress_fn is not None:
-        progress_fn(0, total_kimg)
+        progress_fn(int(resume_kimg), total_kimg)
     while True:
 
         # Fetch training data.
@@ -347,7 +362,7 @@
         # Save image snapshot.
         if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
             images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
-            save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size)
+            save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.jpg'), drange=[-1,1], grid_size=grid_size)
 
         # Save network snapshot.
         snapshot_pkl = None
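
Usage sketch for the new helpers in training/misc.py, assuming a hypothetical results layout; when train.py is launched with --resume=latest, training_loop() performs the same steps automatically, as shown in the hunk above. The directory and snapshot names below are illustrative only.

# Hypothetical example of the resume helpers added in training/misc.py.
from training import misc as tmisc

run_dir = '/results/00005-mydataset-auto_norp'   # hypothetical current run directory
out_dir = tmisc.get_parent_dir(run_dir)          # parent folder that holds all '0*' run folders

# Newest 'network-*.pkl' across previous runs, or None if nothing has been saved yet.
resume_pkl = tmisc.locate_latest_pkl(out_dir)

# 'network-snapshot-000200.pkl' -> 200.0; missing or unparseable names -> 0.0.
resume_kimg = tmisc.parse_kimg_from_network_name(resume_pkl)

# Last 'augment' value logged in that run's log.txt, or 0.0 if none was found.
augment_p = tmisc.parse_augment_p_from_log(resume_pkl)

print(resume_pkl, resume_kimg, augment_p)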