diff --git a/dac/nn/quantize.py b/dac/nn/quantize.py
index b17ff4a..f3ef673 100644
--- a/dac/nn/quantize.py
+++ b/dac/nn/quantize.py
@@ -137,21 +137,18 @@ def forward(self, z, n_quantizers: int = None):
                 when in training mode, and a random number of quantizers is used.
         Returns
         -------
-        dict
-            A dictionary with the following keys:
-
-            "z" : Tensor[B x D x T]
-                Quantized continuous representation of input
-            "codes" : Tensor[B x N x T]
-                Codebook indices for each codebook
-                (quantized discrete representation of input)
-            "latents" : Tensor[B x N*D x T]
-                Projected latents (continuous representation of input before quantization)
-            "vq/commitment_loss" : Tensor[1]
-                Commitment loss to train encoder to predict vectors closer to codebook
-                entries
-            "vq/codebook_loss" : Tensor[1]
-                Codebook loss to update the codebook
+        "z" : Tensor[B x D x T]
+            Quantized continuous representation of input
+        "codes" : Tensor[B x N x T]
+            Codebook indices for each codebook
+            (quantized discrete representation of input)
+        "latents" : Tensor[B x N*D x T]
+            Projected latents (continuous representation of input before quantization)
+        "vq/commitment_loss" : Tensor[1]
+            Commitment loss to train encoder to predict vectors closer to codebook
+            entries
+        "vq/codebook_loss" : Tensor[1]
+            Codebook loss to update the codebook
         """
         z_q = 0
         residual = z
@@ -258,5 +255,5 @@ def from_latents(self, latents: torch.Tensor):
 if __name__ == "__main__":
     rvq = ResidualVectorQuantize(quantizer_dropout=True)
     x = torch.randn(16, 512, 80)
-    y = rvq(x)
-    print(y["latents"].shape)
+    z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
+    print(latents.shape)
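
A minimal usage sketch of this API change, for callers updating from the old dict return value. It is not part of the patch above and assumes only what the diff shows: ResidualVectorQuantize.forward now returns a tuple in the order documented in the updated docstring, instead of a dict keyed by "z", "codes", "latents", "vq/commitment_loss", and "vq/codebook_loss".

    import torch
    from dac.nn.quantize import ResidualVectorQuantize

    rvq = ResidualVectorQuantize(quantizer_dropout=True)
    x = torch.randn(16, 512, 80)  # B x D x T

    # Old API: out = rvq(x); latents = out["latents"]
    # New API: unpack the tuple in the documented order
    z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)

    print(z_q.shape)      # B x D x T    quantized continuous representation
    print(codes.shape)    # B x N x T    codebook indices for each codebook
    print(latents.shape)  # B x N*D x T  projected latents before quantization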