diff --git a/src/pythae/models/vq_vae/vq_vae_utils.py b/src/pythae/models/vq_vae/vq_vae_utils.py index 6a4b6662..54081458 100644 --- a/src/pythae/models/vq_vae/vq_vae_utils.py +++ b/src/pythae/models/vq_vae/vq_vae_utils.py @@ -50,14 +50,14 @@ def forward(self, z: torch.Tensor, uses_ddp: bool = False): commitment_loss = F.mse_loss( quantized.detach().reshape(-1, self.embedding_dim), z.reshape(-1, self.embedding_dim), - reduction="mean", + reduction="sum", ) embedding_loss = F.mse_loss( quantized.reshape(-1, self.embedding_dim), z.detach().reshape(-1, self.embedding_dim), - reduction="mean", - ).mean(dim=-1) + reduction="sum", + ) quantized = z + (quantized - z).detach() @@ -147,7 +147,7 @@ def forward(self, z: torch.Tensor, uses_ddp: bool = False): commitment_loss = F.mse_loss( quantized.detach().reshape(-1, self.embedding_dim), z.reshape(-1, self.embedding_dim), - reduction="mean", + reduction="sum", ) quantized = z + (quantized - z).detach()