
RuntimeError: Error(s) in loading state_dict for SimplePrior: missing key(s) in state_dict #101

pkmital opened this issue Jun 1, 2020 · 1 comment

@pkmital

pkmital commented Jun 1, 2020

Thank you for the great release! I have followed the README instructions for training my own small_prior and am now trying to run the sample.py script, but I'm getting this error:

$ python jukebox/sample.py --model=custom --name=custom --levels=3 --n_samples=6 --sample_length_in_seconds=20 --total_sample_length_in_seconds=180 --sr=44100 --n_samples=6 --hop_fraction=0.5,0.5,0.125
Using cuda True
{'name': 'custom', 'levels': 3, 'n_samples': 6, 'sample_length_in_seconds': 20, 'total_sample_length_in_seconds': 180, 'sr': 44100, 'hop_fraction': (0.5, 0.5, 0.125)}
Setting sample length to 881920 (i.e. 19.998185941043083 seconds) to be multiple of 128
Downloading from gce
Restored from /home/pmital/.cache/jukebox-assets/models/5b/vqvae.pth.tar
0: Loading vqvae in eval mode
Using apex FusedLayerNorm
Conditioning on 1 above level(s)
Checkpointing convs
Checkpointing convs
Loading artist IDs from /home/pmital/dev/jukebox/jukebox/data/ids/v2_artist_ids.txt
Loading artist IDs from /home/pmital/dev/jukebox/jukebox/data/ids/v2_genre_ids.txt
Level:0, Cond downsample:4, Raw to tokens:8, Sample length:65536
Downloading from gce
Restored from /home/pmital/.cache/jukebox-assets/models/5b/prior_level_0.pth.tar
0: Loading prior in eval mode
Conditioning on 1 above level(s)
Checkpointing convs
Checkpointing convs
Loading artist IDs from /home/pmital/dev/jukebox/jukebox/data/ids/v2_artist_ids.txt
Loading artist IDs from /home/pmital/dev/jukebox/jukebox/data/ids/v2_genre_ids.txt
Level:1, Cond downsample:4, Raw to tokens:32, Sample length:262144
Downloading from gce
Restored from /home/pmital/.cache/jukebox-assets/models/5b/prior_level_1.pth.tar
0: Loading prior in eval mode
Conditioning on 1 above level(s)
Level:0, Cond downsample:2, Raw to tokens:2, Sample length:16384
Restored from /home/pmital/dev/jukebox/logs/pretrained_vqvae_small_prior/checkpoint_latest.pth.tar
Traceback (most recent call last):
  File "jukebox/sample.py", line 275, in <module>
    fire.Fire(run)
  File "/home/pmital/anaconda3/envs/jukebox/lib/python3.7/site-packages/fire/core.py", line 127, in Fire
    component_trace = _Fire(component, args, context, name)
  File "/home/pmital/anaconda3/envs/jukebox/lib/python3.7/site-packages/fire/core.py", line 366, in _Fire
    component, remaining_args)
  File "/home/pmital/anaconda3/envs/jukebox/lib/python3.7/site-packages/fire/core.py", line 542, in _CallCallable
    result = fn(*varargs, **kwargs)
  File "jukebox/sample.py", line 272, in run
    save_samples(model, device, hps, sample_hps)
  File "jukebox/sample.py", line 177, in save_samples
    vqvae, priors = make_model(model, device, hps)
  File "/home/pmital/dev/jukebox/jukebox/make_models.py", line 187, in make_model
    priors = [make_prior(setup_hparams(priors[level], dict()), vqvae, 'cpu') for level in levels]
  File "/home/pmital/dev/jukebox/jukebox/make_models.py", line 187, in <listcomp>
    priors = [make_prior(setup_hparams(priors[level], dict()), vqvae, 'cpu') for level in levels]
  File "/home/pmital/dev/jukebox/jukebox/make_models.py", line 171, in make_prior
    restore(hps, prior, hps.restore_prior)
  File "/home/pmital/dev/jukebox/jukebox/make_models.py", line 61, in restore
    model.load_state_dict(checkpoint['model'])
  File "/home/pmital/anaconda3/envs/jukebox/lib/python3.7/site-packages/torch/nn/modules/module.py", line 777, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for SimplePrior:
        Missing key(s) in state_dict: "conditioner_blocks.0.x_emb.weight", "conditioner_blocks.0.cond.model.0.weight", "conditioner_blocks.0.cond.model.0.bias", "conditioner_blocks.0.cond.model.1.0.model.0.model.1.weight", "conditioner_blocks.0.cond.model.1.0.model.0.model.1.bias", "conditioner_blocks.0.cond.model.1.0.model.0.model.3.weight", "conditioner_blocks.0.cond.model.1.0.model.0.model.3.bias", "conditioner_blocks.0.cond.model.1.0.model.1.model.1.weight", "conditioner_blocks.0.cond.model.1.0.model.1.model.1.bias", "conditioner_blocks.0.cond.model.1.0.model.1.model.3.weight", "conditioner_blocks.0.cond.model.1.0.model.1.model.3.bias", "conditioner_blocks.0.cond.model.1.0.model.2.model.1.weight", "conditioner_blocks.0.cond.model.1.0.model.2.model.1.bias", "conditioner_blocks.0.cond.model.1.0.model.2.model.3.weight", "conditioner_blocks.0.cond.model.1.0.model.2.model.3.bias", "conditioner_blocks.0.cond.model.1.1.weight", "conditioner_blocks.0.cond.model.1.1.bias", "conditioner_blocks.0.ln.weight", "conditioner_blocks.0.ln.bias".

I ran training using the command:

python jukebox/train.py --hps=vqvae,small_prior,all_fp16,cpu_ema --name=pretrained_vqvae_small_prior --sample_length=1048576 --bs=4 --aug_shift --aug_blend --audio_files_dir=/mnt/all_silent_films_v1_wavs \
--labels=False --train --test --prior --levels=3 --level=2 --weight_decay=0.01 --save_iters=1000
Using apex fused_adam_cuda
Using cuda True
0: Found 1643 files. Getting durations
0: self.sr=44100, min: 24, max: inf
0: Keeping 1607 of 1643 files
{'l2': 1.1213942968362299e-10, 'l1': 1.1928640560654458e-05, 'spec': 0.8459351658821106}
Creating Data Loader
0: Train 100984 samples. Test 11221 samples
0: Train sampler: <torch.utils.data.distributed.DistributedSampler object at 0x7fd6708e7790>
0: Train loader: 25246
Downloading from gce
Restored from /home/pmital/.cache/jukebox-assets/models/5b/vqvae.pth.tar
0: Loading vqvae in eval mode
Parameters VQVAE:0
Using apex FusedLayerNorm
Level:2, Cond downsample:None, Raw to tokens:128, Sample length:1048576
0: Converting to fp16 params
0: Loading prior in train mode
Parameters Prior:161862656
{'dynamic': True, 'loss_scale': 65536.0, 'max_loss_scale': 16777216.0, 'scale_factor': 1.0027764359010778, 'scale_window': 1, 'unskipped': 0, 'overflow': False}
Using CPU EMA
Logging to logs/pretrained_vqvae_small_prior
0/25246 [00:00<?, ?it/s]Ancestral sampling 4 samples with temp=1.0, top_k=0, top_p=0.0
8192/8192 [05:14<00:00, 26.09it/s]
warning: audio amplitude out of range, auto clipped.
warning: audio amplitude out of range, auto clipped.
warning: audio amplitude out of range, auto clipped.
warning: audio amplitude out of range, auto clipped.
Logging train inputs/ouputs
805/25246 [1:42:42<48:58:42,  7.21s/it, bpd=5.4, g_l=5.4, gn=3.76, l=5.4, p_l=0]
Overflow in backward. Loss 5.543792724609375, grad norm inf, lgscale 19.22000000000012, new lgscale 18.22000000000012
1001/25246 [2:06:13<49:12:14,  7.31s/it, bpd=3.57, g_l=3.57, gn=6.05, l=3.57, p_l=0]Logging train inputs/ouputs
1095/25246 [2:17:41<48:15:58,  7.19s/it, bpd=3.59, g_l=3.59, gn=24.4, l=3.59, p_l=0]
Overflow in backward. Loss 5.444324016571045, grad norm inf, lgscale 19.37600000000016, new lgscale 18.37600000000016
1355/25246 [2:48:53<47:46:12,  7.20s/it, bpd=1.17, g_l=1.17, gn=2.49, l=1.17, p_l=0]
Overflow in backward. Loss 2.776571750640869, grad norm inf, lgscale 19.412000000000198, new lgscale 18.412000000000198
...
...
...
Overflow in backward. Loss 4.581544876098633, grad norm nan, lgscale 17.20000000001162, new lgscale 16.20000000001162
4727/25246 [9:28:28<41:00:18,  7.19s/it, bpd=3.04, g_l=3.04, gn=1.84, l=3.04, p_l=0]
Overflow in backward. Loss 3.136387825012207, grad norm nan, lgscale 16.580000000011633, new lgscale 15.580000000011633
5377/25246 [10:46:30<39:43:52,  7.20s/it, bpd=3.44, g_l=3.44, gn=2.82, l=3.44, p_l=0]
Overflow in backward. Loss 3.6067352294921875, grad norm inf, lgscale 18.17600000001173, new lgscale 17.17600000001173
5584/25246 [11:11:21<39:57:19,  7.32s/it, bpd=1.81, g_l=1.81, gn=2.03, l=1.81, p_l=0]Logging train inputs/ouputs
5603/25246 [11:13:49<39:18:47,  7.21s/it, bpd=3.65, g_l=3.65, gn=2.68, l=3.65, p_l=0]
Overflow in backward. Loss 3.069653034210205, grad norm inf, lgscale 18.076000000011764, new lgscale 17.076000000011764

I updated hparams.py to include:

 small_prior = Hyperparams(
     n_ctx=8192,
     prior_width=1024,
     prior_depth=48,
     heads=1,
     c_res=1,
     attn_order=2,
     blocks=64,
     init_scale=0.7,
     labels=False,
     # ema=True,
     # cpu_ema=True,
     # cpu_ema_freq=100,
     # ema_fused=False,
     # fp16=True,
     # fp16_params=True,
     # fp16_opt=True,
     # fp16_scale_window=250,
     l_bins=2048,
     # y_bins=(0,0), # Set this to (genres, artists) for your dataset
     restore_prior="/home/pmital/dev/jukebox/logs/pretrained_vqvae_small_prior/checkpoint_latest.pth.tar",
 )

and created a model in make_models.py like so:

 MODELS = {
     '5b': ("vqvae", "upsampler_level_0", "upsampler_level_1", "prior_5b"),
     '5b_lyrics': ("vqvae", "upsampler_level_0", "upsampler_level_1", "prior_5b_lyrics"),
     '1b_lyrics': ("vqvae", "upsampler_level_0", "upsampler_level_1", "prior_1b_lyrics"),
     'custom': ("vqvae", "upsampler_level_0", "upsampler_level_1", "small_prior"),
     #'your_model': ("you_vqvae_here", "your_upsampler_here", ..., "you_top_level_prior_here")
 }
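
For reference, this is the line from the traceback where the entries of that tuple get turned into priors (copied verbatim from make_models.py line 187 as it appears in the trace; I haven't modified anything else in that file):

 # make_models.py, line 187 in the traceback: every entry after "vqvae" in the
 # MODELS tuple is looked up by its level index, run through setup_hparams, and
 # passed to make_prior, so the hparams preset named in each slot determines how
 # that prior is constructed before the checkpoint is restored.
 priors = [make_prior(setup_hparams(priors[level], dict()), vqvae, 'cpu') for level in levels]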

I'm on commit 3f54599 ("remove continuations").
Ubuntu 18.04.3 LTS w/ 30 GB RAM
P100 GPU w/ 16 GB RAM, Driver Version: 440.33.01
conda 4.8.3
Python 3.7.5

Thank you for any help!

@prafullasd
Collaborator

prafullasd commented Jun 3, 2020

Add level=2 to the hparams dictionary for small_prior
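
i.e. a minimal sketch of that change, applied to the small_prior block you posted above (commented-out lines omitted, everything else left as-is; untested against your exact setup):

 small_prior = Hyperparams(
     n_ctx=8192,
     prior_width=1024,
     prior_depth=48,
     heads=1,
     c_res=1,
     attn_order=2,
     blocks=64,
     init_scale=0.7,
     labels=False,
     level=2,  # build this as the top-level prior (assuming levels=3 at sampling
               # time), so no conditioner_blocks keys are expected when restoring;
               # this matches a checkpoint trained with --levels=3 --level=2
     l_bins=2048,
     restore_prior="/home/pmital/dev/jukebox/logs/pretrained_vqvae_small_prior/checkpoint_latest.pth.tar",
 )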
