Training issue: AssertionError: Expected torch.Size #93

Closed
ObscuraDK opened this issue May 26, 2020 · 2 comments

Comments

@ObscuraDK

Hi there.

I am trying to train a VQ-VAE, but training ends with the following error message.
I have added the `while` statement to audio_utils.py, as described in #59.

0/96 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "jukebox/train.py", line 342, in <module>
    fire.Fire(run)
  File "/home/vertigo/miniconda3/envs/jukebox/lib/python3.7/site-packages/fire/core.py", line 127, in Fire
    component_trace = _Fire(component, args, context, name)
  File "/home/vertigo/miniconda3/envs/jukebox/lib/python3.7/site-packages/fire/core.py", line 366, in _Fire
    component, remaining_args)
  File "/home/vertigo/miniconda3/envs/jukebox/lib/python3.7/site-packages/fire/core.py", line 542, in _CallCallable
    result = fn(*varargs, **kwargs)
  File "jukebox/train.py", line 325, in run
    train_metrics = train(distributed_model, model, opt, shd, scalar, ema, logger, metrics, data_processor, hps)
  File "jukebox/train.py", line 227, in train
    x_out, loss, _metrics = model(x, **forw_kwargs)
  File "/home/vertigo/miniconda3/envs/jukebox/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/vertigo/miniconda3/envs/jukebox/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 376, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/vertigo/miniconda3/envs/jukebox/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/vertigo/jukebox/jukebox/vqvae/vqvae.py", line 168, in forward
    assert_shape(x_out, x_in.shape)
  File "/home/vertigo/jukebox/jukebox/utils/torch_utils.py", line 25, in assert_shape
    assert x.shape == exp_shape, f"Expected {exp_shape} got {x.shape}"
AssertionError: Expected torch.Size([4, 1, 130976]) got torch.Size([4, 1, 130816])
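The two sizes in the assertion are consistent with an input length that is not a multiple of the VQ-VAE's top-level hop length. The sketch below is only an illustration under stated assumptions: it assumes small_vqvae uses downs_t=(5, 3) and strides_t=(2, 2) (check the small_vqvae entry in jukebox/hparams.py), giving hop lengths of 32 and 256, and that each stride-2 convolution floors odd lengths. Under those assumptions 130976 = 32 × 4093 passes level 0, but 130976 / 256 = 511.625, so the top level reconstructs only 511 × 256 = 130816 samples, which is exactly the shape in the error. The helper names `hop_lengths` and `reconstructed_length` are hypothetical and not part of jukebox.

```python
# Illustrative sketch, not jukebox code.
# Assumption: small_vqvae uses downs_t=(5, 3), strides_t=(2, 2).


def hop_lengths(downs_t, strides_t):
    """Cumulative downsampling factor (hop length) at each VQ-VAE level.

    Assumes level i applies downs_t[i] stride-strides_t[i] convolutions on
    top of level i - 1, so the hops are cumulative products.
    """
    hops = []
    total = 1
    for down, stride in zip(downs_t, strides_t):
        total *= stride ** down
        hops.append(total)
    return hops


def reconstructed_length(sample_length, hop):
    """Length after encode/decode if every strided conv floors the length.

    Any remainder modulo `hop` is lost, so the decoder returns
    floor(sample_length / hop) * hop samples.
    """
    return (sample_length // hop) * hop


hops = hop_lengths(downs_t=(5, 3), strides_t=(2, 2))
for sample_length in (130976, 131072):
    for level, hop in enumerate(hops):
        out = reconstructed_length(sample_length, hop)
        status = "ok" if out == sample_length else "MISMATCH"
        print(f"sample_length={sample_length} level={level} hop={hop} -> {out} ({status})")
```

With these assumptions, 130976 mismatches only at the top level (hop 256), while 131072 (= 512 × 256) round-trips at both levels.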

@ObscuraDK

I have figured out that it is line 225 in train.py that fails, when it runs:
x_out, loss, _metrics = model(x, **forw_kwargs)

@ObscuraDK

Solved: it was a typo in the sample length setting, plus a failure in the multi-GPU setup.

I am running with two GTX 1080 Ti cards.

I ran this and restarted:
sudo nvidia-xconfig -sli=off -multigpu=off

Then I ran:
mpiexec -n 2 python jukebox/train.py --hps=small_vqvae --name=small_vqvae --sample_length=131072 --bs=4 \
    --audio_files_dir={audio_files_dir} --labels=False --train --aug_shift --aug_blend
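Since the root cause was a mistyped sample length, a small pre-flight check on the value passed as `--sample_length` can catch it before `mpiexec` starts the workers. This is a hypothetical helper, not part of jukebox, and it assumes the small_vqvae top level downsamples by 2**5 * 2**3 = 256; adjust `TOP_HOP` if your hyperparameters differ.

```python
# Hypothetical pre-flight check (not part of jukebox).
# Assumption: small_vqvae's top level downsamples by 2**5 * 2**3 = 256.
TOP_HOP = 2**5 * 2**3


def check_sample_length(sample_length, top_hop=TOP_HOP):
    """Return sample_length unchanged if it is a positive multiple of top_hop.

    Otherwise raise ValueError naming the nearest valid lengths below and
    above, so a typo is caught before training starts.
    """
    if sample_length < top_hop:
        raise ValueError(f"sample_length={sample_length} is shorter than one hop ({top_hop})")
    lower = (sample_length // top_hop) * top_hop
    if lower != sample_length:
        raise ValueError(
            f"sample_length={sample_length} is not a multiple of {top_hop}; "
            f"try {lower} or {lower + top_hop}"
        )
    return sample_length


if __name__ == "__main__":
    for candidate in (131072, 130976):
        try:
            print(candidate, "ok:", check_sample_length(candidate))
        except ValueError as err:
            print(candidate, "rejected:", err)
```

Under the assumed hop of 256, 131072 passes, and 130976 is rejected with the suggestions 130816 or 131072, the latter being the value used in the command above.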
