Training and sampling from scratch #104

Open
gogobd opened this issue Jun 6, 2020 · 10 comments

Comments

@gogobd

gogobd commented Jun 6, 2020

I was trying to train and sample completely from scratch and ran into #39. As I described there, I added "labels=False" to the respective params, but now I run into:

Traceback (most recent call last):
  File "jukebox/sample.py", line 275, in <module>
    fire.Fire(run)
  File "/opt/miniconda/lib/python3.7/site-packages/fire/core.py", line 127, in Fire
    component_trace = _Fire(component, args, context, name)
  File "/opt/miniconda/lib/python3.7/site-packages/fire/core.py", line 366, in _Fire
    component, remaining_args)
  File "/opt/miniconda/lib/python3.7/site-packages/fire/core.py", line 542, in _CallCallable
    result = fn(*varargs, **kwargs)
  File "jukebox/sample.py", line 272, in run
    save_samples(model, device, hps, sample_hps)
  File "jukebox/sample.py", line 223, in save_samples
    labels = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in priors ]
  File "jukebox/sample.py", line 223, in <listcomp>
    labels = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in priors ]
  File "/opt/miniconda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 576, in __getattr__
    type(self).__name__, name))
AttributeError: 'SimplePrior' object has no attribute 'labeller'

What can I do to avoid this and get sampling to work without labels?
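
For anyone reading along: the traceback shows sample.py calling `prior.labeller` on every prior, and a prior built without labels has no such attribute. Below is a minimal sketch of guarding that lookup. The classes `Labeller`, `PriorWithLabels` and `PriorWithoutLabels` and the helper `collect_labels` are hypothetical stand-ins, not Jukebox's own code, and whether the rest of the sampling path accepts `None` labels is an assumption that is not verified here.

```python
class Labeller:
    """Hypothetical stand-in: returns one dummy label per metadata entry."""

    def get_batch_labels(self, metas, device):
        return {"y": [0] * len(metas), "device": device}


class PriorWithLabels:
    """Hypothetical prior that was built with labels enabled."""

    def __init__(self):
        self.labeller = Labeller()


class PriorWithoutLabels:
    """Hypothetical prior built with labels disabled: no `labeller` attribute."""


def collect_labels(priors, metas, device="cpu"):
    """Return one label batch per prior, or None where a prior has no labeller."""
    labels = []
    for prior in priors:
        # getattr with a default avoids the AttributeError from the traceback.
        labeller = getattr(prior, "labeller", None)
        if labeller is None:
            labels.append(None)
        else:
            labels.append(labeller.get_batch_labels(metas, device))
    return labels


metas = [{"artist": "unknown"}, {"artist": "unknown"}]
labels = collect_labels([PriorWithoutLabels(), PriorWithLabels()], metas)
print(labels)
```

This only illustrates why the lookup fails and one way to make it tolerant; it is not a drop-in patch for sample.py.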

@gogobd
Author

gogobd commented Jun 6, 2020

Additional information: I followed the steps in README.md and ran these commands to train and sample from scratch:

jukebox/train.py --hps=small_vqvae --name=small_vqvae --sample_length=262144 --bs=4 --nworkers=20 --audio_files_dir=./data/[………] --labels=False --train --aug_shift --aug_blend

train.py --hps=small_vqvae,small_prior,all_fp16,cpu_ema --name=small_prior --sample_length=2097152 --bs=4 --nworkers=10 --audio_files_dir=/data/[………] --labels=False --train --test --aug_shift --aug_blend --restore_vqvae=logs/small_vqvae/checkpoint_step_1000001.pth.tar --prior --levels=2 --level=1 --weight_decay=0.01 --save_iters=1000

train.py --hps=small_vqvae,small_upsampler,all_fp16,cpu_ema --name=small_upsampler --sample_length 262144 --bs 4 --nworkers 4 --audio_files_dir /data/[………] --labels False --train --test --aug_shift --aug_blend --restore_vqvae logs/small_vqvae/checkpoint_step_1000001.pth.tar --prior --levels 2 --level 0 --weight_decay 0.01 --save_iters 1000

@ObscuraDK

Hi gogobd.
I have the same problem and was wondering whether you were able to solve it.

@gogobd
Author

gogobd commented Jun 10, 2020

Hi, @ObscuraDK! Unfortunately I wasn't able to solve it. I tried adding "--labels=False" as an option, as suggested, but then sample.py fails because it expects labels. I also tried adding labels to the params, but I have no idea what they should look like. Unfortunately, the issues here appear to have been unattended for a while now. If you make some progress, please let me know.

@heewooj
Contributor

heewooj commented Jun 11, 2020

Thanks for bringing this up. --labels=False wasn't previously supported in our sampling code. We just pushed a fix with more instructions. Please pull and try again.

@gogobd
Author

gogobd commented Jun 12, 2020

I just saw that you also updated README.md, which now has much better information on how to sample with or without labels! I had an issue with len(y_bins), but now that I've followed your instructions there, it's looking better! Thanks!

@ObscuraDK

Hi @gogobd, I have made a new vqvae, prior and upsampler. They haven't been running for long, but I wanted to see if I could get through.

When I try to sample by running this command:
python jukebox/sample.py --model=piano --name=piano --levels=3 --n_samples=6 --sample_length_in_seconds=40 --total_sample_length_in_seconds=180 --sr=44100 --n_samples=6 --hop_fraction=0.5,0.5,0.125

I get this:
Using cuda True
{'name': 'piano', 'levels': 3, 'n_samples': 6, 'sample_length_in_seconds': 40, 'total_sample_length_in_seconds': 180, 'sr': 44100, 'hop_fraction': (0.5, 0.5, 0.125)}
Setting sample length to 881920 (i.e. 39.996371882086166 seconds) to be multiple of 256
Restored from /home/vertigo/jukebox/logs/piano_small_vqvae/checkpoint_step_20001.pth.tar
0: Loading vqvae in eval mode
Level:1, Cond downsample:None, Raw to tokens:256, Sample length:2097152
Restored from /home/vertigo/jukebox/logs/piano_small_prior/checkpoint_latest.pth.tar
0: Loading prior in eval mode
Conditioning on 1 above level(s)
Checkpointing convs
Checkpointing convs
Checkpointing convs
Level:0, Cond downsample:8, Raw to tokens:32, Sample length:262144
Restored from /home/vertigo/jukebox/logs/piano_small_upsampler/checkpoint_latest.pth.tar
0: Loading prior in eval mode
Traceback (most recent call last):
  File "jukebox/sample.py", line 278, in <module>
    fire.Fire(run)
  File "/home/vertigo/miniconda3/envs/jukebox/lib/python3.7/site-packages/fire/core.py", line 127, in Fire
    component_trace = _Fire(component, args, context, name)
  File "/home/vertigo/miniconda3/envs/jukebox/lib/python3.7/site-packages/fire/core.py", line 366, in _Fire
    component, remaining_args)
  File "/home/vertigo/miniconda3/envs/jukebox/lib/python3.7/site-packages/fire/core.py", line 542, in _CallCallable
    result = fn(*varargs, **kwargs)
  File "jukebox/sample.py", line 275, in run
    save_samples(model, device, hps, sample_hps)
  File "jukebox/sample.py", line 182, in save_samples
    assert hps.sample_length//priors[-2].raw_to_tokens >= priors[-2].n_ctx, f"Upsampling needs atleast one ctx in get_z_conds. Please choose a longer sample length"
AssertionError: Upsampling needs atleast one ctx in get_z_conds. Please choose a longer sample length
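
The "Setting sample length to 881920" line in the log above can be reproduced with a little arithmetic. This is a sketch, not the code from sample.py: the helper name `round_down_to_multiple` is made up for illustration, and the sample rate of 22050 is an assumption inferred from the log itself (881920 samples only comes out to about 39.996 seconds at 22050 Hz, even though the command passes --sr=44100).

```python
def round_down_to_multiple(n_samples, multiple):
    """Largest multiple of `multiple` that does not exceed `n_samples`."""
    return (n_samples // multiple) * multiple


SR = 22050           # assumption: inferred from the log, not from the command line
RAW_TO_TOKENS = 256  # "Raw to tokens:256" on the Level:1 line of the log

requested = 40 * SR  # --sample_length_in_seconds=40
sample_length = round_down_to_multiple(requested, RAW_TO_TOKENS)
seconds = sample_length / SR

print(sample_length, seconds)
```

With these inputs the rounded length matches the value printed in the log.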

@gogobd
Author

gogobd commented Jun 12, 2020

@ObscuraDK I didn't get that error; are you training with labels? Maybe there's a sample_length in the hparams (file) that's too short / small?

@ObscuraDK

ObscuraDK commented Jun 12, 2020

@gogobd I just started from scratch with the latest code from GitHub and followed the guide, with no labels.

@ObscuraDK

ObscuraDK commented Jun 13, 2020

I have been looking into my AssertionError and compared my run against 1b_lyrics.
This is the line where it all breaks:

File "jukebox/sample.py", line 182, in save_samples
assert hps.sample_length//priors[-2].raw_to_tokens >= priors[-2].n_ctx, f"Upsampling needs atleast one ctx in get_z_conds. Please choose a longer sample length"


When I run my own model:
python jukebox/sample.py --model=piano --name=sample_piano --levels=3 --n_samples=6 --sample_length_in_seconds=40 --total_sample_length_in_seconds=180 --sr=44100 --n_samples=6 --hop_fraction=0.5,0.5,0.125

I have these values before it throws the error:
raw tokens: 256
ctx: 8192
hps.samplelength: 881920

When I run the 1b_lyrics:
python jukebox/sample.py --model=1b_lyrics --name=sample_1b --levels=3 --n_samples=6 --sample_length_in_seconds=40 --total_sample_length_in_seconds=180 --sr=44100 --n_samples=6 --hop_fraction=0.5,0.5,0.125

I have these values at line 182 in sample.py
raw tokens: 32
ctx: 8192
hps.samplelength: 1763968

@heewooj any ideas on how I can either extend hps.samplelength or reduce the raw tokens in my trained model?
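
Plugging the two sets of printed values into the condition from the assertion shows why one run passes and the other does not. The sketch below only restates that inequality; the helper names are hypothetical, and the sample rate of 22050 for the custom model is an assumption inferred from the "Setting sample length" log line.

```python
import math


def has_full_context(sample_length, raw_to_tokens, n_ctx):
    """Condition from the assertion: tokens available must cover one context."""
    return sample_length // raw_to_tokens >= n_ctx


def min_whole_seconds(raw_to_tokens, n_ctx, sr):
    """Smallest whole number of seconds that yields at least n_ctx tokens."""
    return math.ceil(raw_to_tokens * n_ctx / sr)


# Values printed for the custom model: 881920 // 256 = 3445 tokens, below 8192.
own_ok = has_full_context(881920, 256, 8192)

# Values printed for 1b_lyrics: 1763968 // 32 = 55124 tokens, above 8192.
ref_ok = has_full_context(1763968, 32, 8192)

# 256 * 8192 = 2097152 samples are needed; at an assumed 22050 Hz that is ~95.1 s.
needed_seconds = min_whole_seconds(256, 8192, 22050)

print(own_ok, ref_ok, needed_seconds)
```

Under these assumptions, the shortest whole-second value for --sample_length_in_seconds that satisfies the check is 96.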

@ObscuraDK

Solved it by raising sample_length_in_seconds to 96. Now I get this error:

Using cuda True
{'name': 'piano', 'levels': 2, 'n_samples': 6, 'sample_length_in_seconds': 96, 'total_sample_length_in_seconds': 384, 'sr': 44100, 'hop_fraction': (0.5, 0.5, 0.125)}
Setting sample length to 2116608 (i.e. 95.9912925170068 seconds) to be multiple of 256
Restored from /home/vertigo/jukebox/logs/piano_small_vqvae/checkpoint_step_40001.pth.tar
0: Loading vqvae in eval mode
Level:1, Cond downsample:None, Raw to tokens:256, Sample length:2097152
Restored from /home/vertigo/jukebox/logs/piano_small_prior/checkpoint_latest.pth.tar
0: Loading prior in eval mode
Conditioning on 1 above level(s)
Checkpointing convs
Checkpointing convs
Checkpointing convs
Level:0, Cond downsample:8, Raw to tokens:32, Sample length:262144
Restored from /home/vertigo/jukebox/logs/piano_small_upsampler/checkpoint_latest.pth.tar
0: Loading prior in eval mode
raw tokens: 256
ctx: 8192
hps.samplelength: 2116608
Sampling level 1
Sampling 8192 tokens for [0,8192]. Conditioning on 0 tokens
Traceback (most recent call last):
  File "jukebox/sample.py", line 280, in <module>
    fire.Fire(run)
  File "/home/vertigo/miniconda3/envs/jukebox/lib/python3.7/site-packages/fire/core.py", line 127, in Fire
    component_trace = _Fire(component, args, context, name)
  File "/home/vertigo/miniconda3/envs/jukebox/lib/python3.7/site-packages/fire/core.py", line 366, in _Fire
    component, remaining_args)
  File "/home/vertigo/miniconda3/envs/jukebox/lib/python3.7/site-packages/fire/core.py", line 542, in _CallCallable
    result = fn(*varargs, **kwargs)
  File "jukebox/sample.py", line 277, in run
    save_samples(model, device, hps, sample_hps)
  File "jukebox/sample.py", line 245, in save_samples
    ancestral_sample(labels, sampling_kwargs, priors, hps)
  File "jukebox/sample.py", line 126, in ancestral_sample
    zs = _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps)
  File "jukebox/sample.py", line 102, in _sample
    zs = sample_level(zs, labels[level], sampling_kwargs[level], level, prior, total_length, hop_length, hps)
  File "jukebox/sample.py", line 85, in sample_level
    zs = sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps)
  File "jukebox/sample.py", line 53, in sample_single_window
    z_conds = prior.get_z_conds(zs, start, end)
  File "/home/vertigo/jukebox/jukebox/prior/prior.py", line 162, in get_z_conds
    assert z_cond.shape[1] == self.n_ctx//self.cond_downsample
AssertionError
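
The failing check compares the number of conditioning tokens with `n_ctx // cond_downsample`. Here is a rough sketch of that comparison using the numbers from the log (ctx 8192, Cond downsample 8). The slicing in `cond_slice` is an assumption about how the conditioning window might be cut from the level above, not code taken from prior.py, and plain lists stand in for tensors.

```python
def expected_cond_tokens(n_ctx, cond_downsample):
    """Right-hand side of the assertion in the traceback."""
    return n_ctx // cond_downsample


def cond_slice(upper_tokens, start, end, cond_downsample):
    """Assumed way of cutting the conditioning window from the level above."""
    return upper_tokens[start // cond_downsample:end // cond_downsample]


n_ctx, cond_downsample = 8192, 8
expected = expected_cond_tokens(n_ctx, cond_downsample)

upper_with_tokens = list(range(2048))  # level above already holds tokens
upper_empty = []                       # level above has not been sampled yet

ok = len(cond_slice(upper_with_tokens, 0, n_ctx, cond_downsample)) == expected
bad = len(cond_slice(upper_empty, 0, n_ctx, cond_downsample)) == expected

print(expected, ok, bad)
```

If the level above holds no tokens yet, the window comes back empty and the equality cannot hold. One thing visible in the log, offered only as an observation: the Level:1 prior is restored before the Level:0 upsampler, and the value printed for priors[-2].raw_to_tokens is 256 rather than the 32 printed for 1b_lyrics. Whether that ordering is related to the failure is not confirmed anywhere in this thread.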
