Skip to content

Commit

Permalink
Merge pull request #83 from openai/multiple_vqvae
Browse files Browse the repository at this point in the history
Multiple vqvae
  • Loading branch information
prafullasd authored May 18, 2020
2 parents 3f54599 + d3754bd commit f326b52
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 11 deletions.
4 changes: 2 additions & 2 deletions jukebox/data/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def set_y_lyric_tokens(self, ys, labels):
def describe_label(self, y):
assert y.shape == self.label_shape, f"Expected {self.label_shape}, got {y.shape}"
y = np.array(y).tolist()
total_length, offset, length, artist_id, *genre_ids = y[:-self.n_tokens]
tokens = y[-self.n_tokens:]
total_length, offset, length, artist_id, *genre_ids = y[:4 + self.max_genre_words]
tokens = y[4 + self.max_genre_words:]
artist = self.ag_processor.get_artist(artist_id)
genre = self.ag_processor.get_genre(genre_ids)
lyrics = self.text_processor.textise(tokens)
Expand Down
4 changes: 0 additions & 4 deletions jukebox/hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def setup_hparams(hparam_set_names, kwargs):
prime_loss_fraction=0.0,
fp16_params=False,
)
upsamplers.update(vqvae)
upsamplers.update(labels)

upsampler_level_0 = Hyperparams(
Expand Down Expand Up @@ -119,7 +118,6 @@ def setup_hparams(hparam_set_names, kwargs):
restore_prior='gs://jukebox-assets/models/5b/prior_level_2.pth.tar',
fp16_params=True,
)
prior_5b.update(vqvae)
prior_5b.update(labels)
HPARAMS_REGISTRY["prior_5b"] = prior_5b

Expand Down Expand Up @@ -152,7 +150,6 @@ def setup_hparams(hparam_set_names, kwargs):
alignment_layer=68,
alignment_head=2,
)
prior_5b_lyrics.update(vqvae)
prior_5b_lyrics.update(labels)
HPARAMS_REGISTRY["prior_5b_lyrics"] = prior_5b_lyrics

Expand Down Expand Up @@ -185,7 +182,6 @@ def setup_hparams(hparam_set_names, kwargs):
alignment_layer=63,
alignment_head=0,
)
prior_1b_lyrics.update(vqvae)
prior_1b_lyrics.update(labels_v3)
HPARAMS_REGISTRY["prior_1b_lyrics"] = prior_1b_lyrics

Expand Down
8 changes: 4 additions & 4 deletions jukebox/make_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def make_vqvae(hps, device='cuda'):
def make_prior(hps, vqvae, device='cuda'):
from jukebox.prior.prior import SimplePrior

prior_kwargs = dict(input_shape=(hps.n_ctx,), bins=hps.l_bins,
prior_kwargs = dict(input_shape=(hps.n_ctx,), bins=vqvae.l_bins,
width=hps.prior_width, depth=hps.prior_depth, heads=hps.heads,
attn_order=hps.attn_order, blocks=hps.blocks, spread=hps.spread,
attn_dropout=hps.attn_dropout, resid_dropout=hps.resid_dropout, emb_dropout=hps.emb_dropout,
Expand Down Expand Up @@ -142,12 +142,12 @@ def make_prior(hps, vqvae, device='cuda'):
z_shapes = [rescale(z_shape) for z_shape in vqvae.z_shapes]

prior = SimplePrior(z_shapes=z_shapes,
l_bins=hps.l_bins,
l_bins=vqvae.l_bins,
encoder=vqvae.encode,
decoder=vqvae.decode,
level=hps.level,
downs_t=hps.downs_t,
strides_t=hps.strides_t,
downs_t=vqvae.downs_t,
strides_t=vqvae.strides_t,
labels=hps.labels,
prior_kwargs=prior_kwargs,
x_cond_kwargs=x_cond_kwargs,
Expand Down
2 changes: 1 addition & 1 deletion jukebox/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def evaluate(model, orig_model, logger, metrics, data_processor, hps):
_print_keys = dict(l="loss", rl="recons_loss", sl="spectral_loss")

with t.no_grad():
for i, x in logger.get_range(data_processor.train_loader):
for i, x in logger.get_range(data_processor.test_loader):
if isinstance(x, (tuple, list)):
x, y = x
else:
Expand Down
2 changes: 2 additions & 0 deletions jukebox/vqvae/vqvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def _block_kwargs(level):
else:
self.bottleneck = NoBottleneck(levels)

self.downs_t = downs_t
self.strides_t = strides_t
self.l_bins = l_bins
self.commit = commit
self.spectral = spectral
Expand Down

0 comments on commit f326b52

Please sign in to comment.