Skip to content

Commit 7277264

Browse files
committed
Allow to enforce CIFAR-specific architecture tuning with --cifar_tune
1 parent 7b64651 commit 7277264

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

train.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def setup_training_loop_kwargs(
4444

4545
# Base config.
4646
cfg = None, # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
47+
cifar_tune = None, # Enforce CIFAR-specific architecture tuning: <bool>, default = False
4748
gamma = None, # Override R1 gamma: <float>
4849
kimg = None, # Override training duration: <int>
4950
batch = None, # Override batch size: <int>
@@ -194,7 +195,14 @@ def setup_training_loop_kwargs(
194195
args.ema_kimg = spec.ema
195196
args.ema_rampup = spec.ramp
196197

197-
if cfg == 'cifar':
198+
if cifar_tune is None:
199+
cifar_tune = False
200+
else:
201+
assert isinstance(cifar_tune, bool)
202+
if cifar_tune:
203+
desc += '-tuning'
204+
205+
if cifar_tune or cfg == 'cifar':
198206
args.loss_kwargs.pl_weight = 0 # disable path length regularization
199207
args.loss_kwargs.style_mixing_prob = 0 # disable style mixing
200208
args.D_kwargs.architecture = 'orig' # disable residual skip connections
@@ -422,6 +430,7 @@ def convert(self, value, param, ctx):
422430

423431
# Base config.
424432
@click.option('--cfg', help='Base config [default: auto]', type=click.Choice(['auto', 'auto_norp', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar']))
433+
@click.option('--cifar_tune', help='Enforce CIFAR-specific architecture tuning (default: false)', type=bool, metavar='BOOL')
425434
@click.option('--gamma', help='Override R1 gamma', type=float)
426435
@click.option('--kimg', help='Override training duration', type=int, metavar='INT')
427436
@click.option('--batch', help='Override batch size', type=int, metavar='INT')

0 commit comments

Comments
 (0)